diff --git a/sklearnex/tests/test_memory_usage.py b/sklearnex/tests/test_memory_usage.py index aa92df1d6a..2d52a545cf 100644 --- a/sklearnex/tests/test_memory_usage.py +++ b/sklearnex/tests/test_memory_usage.py @@ -35,10 +35,14 @@ get_dataframes_and_queues, ) from onedal.tests.utils._device_selection import get_queues, is_dpctl_device_available -from onedal.utils._array_api import _get_sycl_namespace from onedal.utils._dpep_helpers import dpctl_available, dpnp_available from sklearnex import config_context -from sklearnex.tests.utils import PATCHED_FUNCTIONS, PATCHED_MODELS, SPECIAL_INSTANCES +from sklearnex.tests.utils import ( + PATCHED_FUNCTIONS, + PATCHED_MODELS, + SPECIAL_INSTANCES, + DummyEstimator, +) from sklearnex.utils._array_api import get_namespace if dpctl_available: @@ -131,41 +135,6 @@ def gen_functions(functions): ORDER_DICT = {"F": np.asfortranarray, "C": np.ascontiguousarray} -if _is_dpc_backend: - - from sklearn.utils.validation import check_is_fitted - - from onedal.datatypes import from_table, to_table - - class DummyEstimatorWithTableConversions(BaseEstimator): - - def fit(self, X, y=None): - sua_iface, xp, _ = _get_sycl_namespace(X) - X_table = to_table(X) - y_table = to_table(y) - # The presence of the fitted attributes (ending with a trailing - # underscore) is required for the correct check. The cleanup of - # the memory will occur at the estimator instance deletion. - self.x_attr_ = from_table( - X_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp - ) - self.y_attr_ = from_table( - y_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp - ) - return self - - def predict(self, X): - # Checks if the estimator is fitted by verifying the presence of - # fitted attributes (ending with a trailing underscore). - check_is_fitted(self) - sua_iface, xp, _ = _get_sycl_namespace(X) - X_table = to_table(X) - returned_X = from_table( - X_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp - ) - return returned_X - - def gen_clsf_data(n_samples, n_features, dtype=None): data, label = make_classification( n_classes=2, n_samples=n_samples, n_features=n_features, random_state=777 @@ -369,7 +338,7 @@ def test_table_conversions_memory_leaks(dataframe, queue, order, data_shape, dty pytest.skip("SYCL device memory leak check requires the level zero sysman") _kfold_function_template( - DummyEstimatorWithTableConversions, + DummyEstimator, dataframe, data_shape, queue, diff --git a/sklearnex/tests/utils/__init__.py b/sklearnex/tests/utils/__init__.py index 60ca67fa37..db728fe913 100644 --- a/sklearnex/tests/utils/__init__.py +++ b/sklearnex/tests/utils/__init__.py @@ -21,6 +21,7 @@ SPECIAL_INSTANCES, UNPATCHED_FUNCTIONS, UNPATCHED_MODELS, + DummyEstimator, _get_processor_info, call_method, gen_dataset, @@ -39,6 +40,7 @@ "gen_models_info", "gen_dataset", "sklearn_clone_dict", + "DummyEstimator", ] _IS_INTEL = "GenuineIntel" in _get_processor_info() diff --git a/sklearnex/tests/utils/base.py b/sklearnex/tests/utils/base.py index 1949519585..33d3804b8f 100755 --- a/sklearnex/tests/utils/base.py +++ b/sklearnex/tests/utils/base.py @@ -32,8 +32,11 @@ ) from sklearn.datasets import load_diabetes, load_iris from sklearn.neighbors._base import KNeighborsMixin +from sklearn.utils.validation import check_is_fitted +from onedal.datatypes import from_table, to_table from onedal.tests.utils._dataframes_support import _convert_to_dataframe +from onedal.utils._array_api import _get_sycl_namespace from sklearnex import get_patch_map, patch_sklearn, sklearn_is_patched, 
unpatch_sklearn
from sklearnex.basic_statistics import BasicStatistics, IncrementalBasicStatistics
from sklearnex.linear_model import LogisticRegression
@@ -369,3 +372,44 @@ def _get_processor_info():
     )
 
     return proc
+
+
+class DummyEstimator(BaseEstimator):
+
+    def fit(self, X, y=None):
+        sua_iface, xp, _ = _get_sycl_namespace(X)
+        X_table = to_table(X)
+        y_table = to_table(y)
+        # The presence of the fitted attributes (ending with a trailing
+        # underscore) is required for the correct check. The cleanup of
+        # the memory will occur at the estimator instance deletion.
+        if sua_iface:
+            self.x_attr_ = from_table(
+                X_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp
+            )
+            self.y_attr_ = from_table(
+                y_table,
+                sua_iface=sua_iface,
+                sycl_queue=X.sycl_queue if y is None else y.sycl_queue,
+                xp=xp,
+            )
+        else:
+            self.x_attr_ = from_table(X_table)
+            self.y_attr_ = from_table(y_table)
+
+        return self
+
+    def predict(self, X):
+        # Checks if the estimator is fitted by verifying the presence of
+        # fitted attributes (ending with a trailing underscore).
+        check_is_fitted(self)
+        sua_iface, xp, _ = _get_sycl_namespace(X)
+        X_table = to_table(X)
+        if sua_iface:
+            returned_X = from_table(
+                X_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp
+            )
+        else:
+            returned_X = from_table(X_table)
+
+        return returned_X
diff --git a/sklearnex/utils/__init__.py b/sklearnex/utils/__init__.py
index 4c3fe21154..686e089adf 100755
--- a/sklearnex/utils/__init__.py
+++ b/sklearnex/utils/__init__.py
@@ -14,6 +14,6 @@
 # limitations under the License.
 # ===============================================================================
 
-from .validation import _assert_all_finite
+from .validation import assert_all_finite
 
-__all__ = ["_assert_all_finite"]
+__all__ = ["assert_all_finite"]
diff --git a/sklearnex/utils/tests/test_finite.py b/sklearnex/utils/tests/test_finite.py
deleted file mode 100644
index 7d83667699..0000000000
--- a/sklearnex/utils/tests/test_finite.py
+++ /dev/null
@@ -1,89 +0,0 @@
-# ==============================================================================
-# Copyright 2024 Intel Corporation
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================== - -import time - -import numpy as np -import numpy.random as rand -import pytest -from numpy.testing import assert_raises - -from sklearnex.utils import _assert_all_finite - - -@pytest.mark.parametrize("dtype", [np.float32, np.float64]) -@pytest.mark.parametrize( - "shape", - [ - [16, 2048], - [ - 2**16 + 3, - ], - [1000, 1000], - ], -) -@pytest.mark.parametrize("allow_nan", [False, True]) -def test_sum_infinite_actually_finite(dtype, shape, allow_nan): - X = np.empty(shape, dtype=dtype) - X.fill(np.finfo(dtype).max) - _assert_all_finite(X, allow_nan=allow_nan) - - -@pytest.mark.parametrize("dtype", [np.float32, np.float64]) -@pytest.mark.parametrize( - "shape", - [ - [16, 2048], - [ - 65539, # 2**16 + 3, - ], - [1000, 1000], - ], -) -@pytest.mark.parametrize("allow_nan", [False, True]) -@pytest.mark.parametrize("check", ["inf", "NaN", None]) -@pytest.mark.parametrize("seed", [0, int(time.time())]) -def test_assert_finite_random_location(dtype, shape, allow_nan, check, seed): - rand.seed(seed) - X = rand.uniform(high=np.finfo(dtype).max, size=shape).astype(dtype) - - if check: - loc = rand.randint(0, X.size - 1) - X.reshape((-1,))[loc] = float(check) - - if check is None or (allow_nan and check == "NaN"): - _assert_all_finite(X, allow_nan=allow_nan) - else: - assert_raises(ValueError, _assert_all_finite, X, allow_nan=allow_nan) - - -@pytest.mark.parametrize("dtype", [np.float32, np.float64]) -@pytest.mark.parametrize("allow_nan", [False, True]) -@pytest.mark.parametrize("check", ["inf", "NaN", None]) -@pytest.mark.parametrize("seed", [0, int(time.time())]) -def test_assert_finite_random_shape_and_location(dtype, allow_nan, check, seed): - lb, ub = 32768, 1048576 # lb is a patching condition, ub 2^20 - rand.seed(seed) - X = rand.uniform(high=np.finfo(dtype).max, size=rand.randint(lb, ub)).astype(dtype) - - if check: - loc = rand.randint(0, X.size - 1) - X[loc] = float(check) - - if check is None or (allow_nan and check == "NaN"): - _assert_all_finite(X, allow_nan=allow_nan) - else: - assert_raises(ValueError, _assert_all_finite, X, allow_nan=allow_nan) diff --git a/sklearnex/utils/tests/test_validation.py b/sklearnex/utils/tests/test_validation.py new file mode 100644 index 0000000000..37d0a6df6e --- /dev/null +++ b/sklearnex/utils/tests/test_validation.py @@ -0,0 +1,240 @@ +# ============================================================================== +# Copyright contributors to the oneDAL project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+
+import time
+
+import numpy as np
+import numpy.random as rand
+import pytest
+
+from daal4py.sklearn._utils import sklearn_check_version
+from onedal.tests.utils._dataframes_support import (
+    _convert_to_dataframe,
+    get_dataframes_and_queues,
+)
+from sklearnex import config_context
+from sklearnex.tests.utils import DummyEstimator, gen_dataset
+from sklearnex.utils.validation import _check_sample_weight, validate_data
+
+# array_api support starts in sklearn 1.2, and array_api_strict conformance starts in sklearn 1.3
+_dataframes_supported = (
+    "numpy,pandas"
+    + (",dpctl" if sklearn_check_version("1.2") else "")
+    + (",array_api" if sklearn_check_version("1.3") else "")
+)
+
+
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+@pytest.mark.parametrize(
+    "shape",
+    [
+        [16, 2048],
+        [2**16 + 3],
+        [1000, 1000],
+    ],
+)
+@pytest.mark.parametrize("ensure_all_finite", ["allow-nan", True])
+def test_sum_infinite_actually_finite(dtype, shape, ensure_all_finite):
+    est = DummyEstimator()
+    X = np.empty(shape, dtype=dtype)
+    X.fill(np.finfo(dtype).max)
+    X = np.atleast_2d(X)
+    X_array = validate_data(est, X, ensure_all_finite=ensure_all_finite)
+    assert type(X_array) == type(X)
+
+
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+@pytest.mark.parametrize(
+    "shape",
+    [
+        [16, 2048],
+        [2**16 + 3],
+        [1000, 1000],
+    ],
+)
+@pytest.mark.parametrize("ensure_all_finite", ["allow-nan", True])
+@pytest.mark.parametrize("check", ["inf", "NaN", None])
+@pytest.mark.parametrize("seed", [0, int(time.time())])
+@pytest.mark.parametrize(
+    "dataframe, queue",
+    get_dataframes_and_queues(_dataframes_supported),
+)
+def test_validate_data_random_location(
+    dataframe, queue, dtype, shape, ensure_all_finite, check, seed
+):
+    est = DummyEstimator()
+    rand.seed(seed)
+    X = rand.uniform(high=np.finfo(dtype).max, size=shape).astype(dtype)
+
+    if check:
+        loc = rand.randint(0, X.size - 1)
+        X.reshape((-1,))[loc] = float(check)
+
+    # Column-heavy pandas inputs are very slow in sklearn's check_array even
+    # without the finite check; transpose the input to guarantee fast
+    # processing in tests.
+    X = _convert_to_dataframe(
+        np.atleast_2d(X).T,
+        target_df=dataframe,
+        sycl_queue=queue,
+    )
+
+    dispatch = {}
+    if sklearn_check_version("1.2") and dataframe != "pandas":
+        dispatch["array_api_dispatch"] = True
+
+    with config_context(**dispatch):
+
+        allow_nan = ensure_all_finite == "allow-nan"
+        if check is None or (allow_nan and check == "NaN"):
+            validate_data(est, X, ensure_all_finite=ensure_all_finite)
+        else:
+            type_err = "infinity" if allow_nan else "[NaN|infinity]"
+            msg_err = f"Input X contains {type_err}"
+            with pytest.raises(ValueError, match=msg_err):
+                validate_data(est, X, ensure_all_finite=ensure_all_finite)
+
+
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+@pytest.mark.parametrize("ensure_all_finite", ["allow-nan", True])
+@pytest.mark.parametrize("check", ["inf", "NaN", None])
+@pytest.mark.parametrize("seed", [0, int(time.time())])
+@pytest.mark.parametrize(
+    "dataframe, queue",
+    get_dataframes_and_queues(_dataframes_supported),
+)
+def test_validate_data_random_shape_and_location(
+    dataframe, queue, dtype, ensure_all_finite, check, seed
+):
+    est = DummyEstimator()
+    lb, ub = 32768, 1048576  # lb is a patching condition, ub is 2^20
+    rand.seed(seed)
+    X = rand.uniform(high=np.finfo(dtype).max, size=rand.randint(lb, ub)).astype(dtype)
+
+    if check:
+        loc = rand.randint(0, X.size - 1)
+        X[loc] = float(check)
+
+    X = _convert_to_dataframe(
+        np.atleast_2d(X).T,
+        target_df=dataframe,
+        sycl_queue=queue,
+    )
+
+    dispatch = {}
+    if sklearn_check_version("1.2") and dataframe != "pandas":
+        dispatch["array_api_dispatch"] = True
+
+    with config_context(**dispatch):
+
+        allow_nan = ensure_all_finite == "allow-nan"
+        if check is None or (allow_nan and check == "NaN"):
+            validate_data(est, X, ensure_all_finite=ensure_all_finite)
+        else:
+            type_err = "infinity" if allow_nan else "[NaN|infinity]"
+            msg_err = f"Input X contains {type_err}"
+            with pytest.raises(ValueError, match=msg_err):
+                validate_data(est, X, ensure_all_finite=ensure_all_finite)
+
+
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+@pytest.mark.parametrize("check", ["inf", "NaN", None])
+@pytest.mark.parametrize("seed", [0, int(time.time())])
+@pytest.mark.parametrize(
+    "dataframe, queue",
+    get_dataframes_and_queues(_dataframes_supported),
+)
+def test__check_sample_weight_random_shape_and_location(
+    dataframe, queue, dtype, check, seed
+):
+    # This testing assumes that array api inputs to validate_data will only occur
+    # with sklearn array_api support which began in sklearn 1.2. This would assume
+    # that somewhere upstream of the validate_data call, a data conversion of dpnp,
+    # dpctl, or array_api inputs to numpy inputs would have occurred.
+
+    lb, ub = 32768, 1048576  # lb is a patching condition, ub is 2^20
+    rand.seed(seed)
+    shape = (rand.randint(lb, ub), 2)
+    X = rand.uniform(high=np.finfo(dtype).max, size=shape).astype(dtype)
+    sample_weight = rand.uniform(high=np.finfo(dtype).max, size=shape[0]).astype(dtype)
+
+    if check:
+        loc = rand.randint(0, shape[0] - 1)
+        sample_weight[loc] = float(check)
+
+    X = _convert_to_dataframe(
+        X,
+        target_df=dataframe,
+        sycl_queue=queue,
+    )
+    sample_weight = _convert_to_dataframe(
+        sample_weight,
+        target_df=dataframe,
+        sycl_queue=queue,
+    )
+
+    dispatch = {}
+    if sklearn_check_version("1.2") and dataframe != "pandas":
+        dispatch["array_api_dispatch"] = True
+
+    with config_context(**dispatch):
+
+        if check is None:
+            sample_weight_out = _check_sample_weight(sample_weight, X)
+            if dispatch:
+                assert type(sample_weight_out) == type(X)
+            else:
+                assert isinstance(sample_weight_out, np.ndarray)
+        else:
+            msg_err = "Input sample_weight contains [NaN|infinity]"
+            with pytest.raises(ValueError, match=msg_err):
+                _check_sample_weight(sample_weight, X)
+
+
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+@pytest.mark.parametrize(
+    "dataframe, queue",
+    get_dataframes_and_queues(_dataframes_supported),
+)
+def test_validate_data_output(dtype, dataframe, queue):
+    # This testing assumes that array api inputs to validate_data will only occur
+    # with sklearn array_api support which began in sklearn 1.2. This would assume
+    # that somewhere upstream of the validate_data call, a data conversion of dpnp,
+    # dpctl, or array_api inputs to numpy inputs would have occurred.
+    est = DummyEstimator()
+    X, y = gen_dataset(est, queue=queue, target_df=dataframe, dtype=dtype)[0]
+
+    dispatch = {}
+    if sklearn_check_version("1.2") and dataframe != "pandas":
+        dispatch["array_api_dispatch"] = True
+
+    with config_context(**dispatch):
+        X_out, y_out = validate_data(est, X, y)
+        # check that sklearn's validate_data operations work underneath
+        X_array = validate_data(est, X, reset=False)
+
+        for orig, first, second in ((X, X_out, X_array), (y, y_out, None)):
+            if dispatch:
+                assert type(orig) == type(
+                    first
+                ), f"validate_data converted {type(orig)} to {type(first)}"
+                if second is not None:
+                    assert type(orig) == type(
+                        second
+                    ), f"validate_data converted {type(orig)} to {type(second)}"
+            else:
+                # array_api_strict from sklearn < 1.2 and pandas will convert to numpy arrays
+                assert isinstance(first, np.ndarray)
+                assert second is None or isinstance(second, np.ndarray)
diff --git a/sklearnex/utils/validation.py b/sklearnex/utils/validation.py
index b2d1898643..4d12602d74 100755
--- a/sklearnex/utils/validation.py
+++ b/sklearnex/utils/validation.py
@@ -14,4 +14,195 @@
 # limitations under the License.
 # ===============================================================================
 
-from daal4py.sklearn.utils.validation import _assert_all_finite
+import numbers
+
+import scipy.sparse as sp
+from sklearn.utils.validation import _assert_all_finite as _sklearn_assert_all_finite
+from sklearn.utils.validation import _num_samples, check_array, check_non_negative
+
+from daal4py.sklearn._utils import daal_check_version, sklearn_check_version
+
+from ._array_api import get_namespace
+
+if sklearn_check_version("1.6"):
+    from sklearn.utils.validation import validate_data as _sklearn_validate_data
+
+    _finite_keyword = "ensure_all_finite"
+
+else:
+    from sklearn.base import BaseEstimator
+
+    _sklearn_validate_data = BaseEstimator._validate_data
+    _finite_keyword = "force_all_finite"
+
+
+if daal_check_version((2024, "P", 700)):
+    from onedal.utils.validation import _assert_all_finite as _onedal_assert_all_finite
+
+    def _onedal_supported_format(X, xp):
+        # array_api does not have a `strides` or `flags` attribute for testing memory
+        # order. When dlpack support is brought in for oneDAL, general support for
+        # array_api can be enabled and the hasattr check can be removed.
+        # _onedal_supported_format is therefore conservative in verifying attributes and
+        # does not support array_api. This will block onedal_assert_all_finite from being
+        # used for array_api inputs but will allow dpnp ndarrays and dpctl tensors.
+        # Only check contiguous arrays to prevent unnecessary copying of data, even if
+        # non-contiguous arrays can now be converted to oneDAL tables.
+        return (
+            X.dtype in [xp.float32, xp.float64]
+            and hasattr(X, "flags")
+            and (X.flags["C_CONTIGUOUS"] or X.flags["F_CONTIGUOUS"])
+        )
+
+else:
+    from daal4py.sklearn.utils.validation import (
+        _assert_all_finite as _onedal_assert_all_finite,
+    )
+    from onedal.utils._array_api import _is_numpy_namespace
+
+    def _onedal_supported_format(X, xp):
+        # daal4py _assert_all_finite only supports numpy namespaces; use the
+        # internally-defined check to validate inputs, otherwise offload to sklearn
+        return X.dtype in [xp.float32, xp.float64] and _is_numpy_namespace(xp)
+
+
+def _sklearnex_assert_all_finite(
+    X,
+    *,
+    allow_nan=False,
+    input_name="",
+):
+    # the size check initially matches daal4py for performance reasons and
+    # can be optimized later
+    xp, _ = get_namespace(X)
+    if X.size < 32768 or not _onedal_supported_format(X, xp):
+        if sklearn_check_version("1.1"):
+            _sklearn_assert_all_finite(X, allow_nan=allow_nan, input_name=input_name)
+        else:
+            _sklearn_assert_all_finite(X, allow_nan=allow_nan)
+    else:
+        _onedal_assert_all_finite(X, allow_nan=allow_nan, input_name=input_name)
+
+
+def assert_all_finite(
+    X,
+    *,
+    allow_nan=False,
+    input_name="",
+):
+    _sklearnex_assert_all_finite(
+        X.data if sp.issparse(X) else X,
+        allow_nan=allow_nan,
+        input_name=input_name,
+    )
+
+
+def validate_data(
+    _estimator,
+    /,
+    X="no_validation",
+    y="no_validation",
+    **kwargs,
+):
+    # force the finite check to not occur in sklearn; the default is True
+    # `ensure_all_finite` is the most up-to-date keyword name in sklearn
+    # _finite_keyword provides backward compatibility for `force_all_finite`
+    ensure_all_finite = kwargs.pop("ensure_all_finite", True)
+    kwargs[_finite_keyword] = False
+
+    out = _sklearn_validate_data(
+        _estimator,
+        X=X,
+        y=y,
+        **kwargs,
+    )
+
+    check_x = not isinstance(X, str) or X != "no_validation"
+    check_y = not (y is None or isinstance(y, str) and y == "no_validation")
+
+    if ensure_all_finite:
+        # run local finite check
+        allow_nan = ensure_all_finite == "allow-nan"
+        # the return object from validate_data can be a single
+        # element (either x or y) or both (as a tuple). An iterator along with
+        # check_x and check_y can go through the output properly without
+        # stacking layers of if statements to make sure the proper input_name
+        # is used
+        arg = iter(out if isinstance(out, tuple) else (out,))
+        if check_x:
+            assert_all_finite(next(arg), allow_nan=allow_nan, input_name="X")
+        if check_y:
+            assert_all_finite(next(arg), allow_nan=allow_nan, input_name="y")
+
+    if check_y and "dtype" in kwargs:
+        # validate_data does not do full dtype conversions, as it uses check_X_y
+        # oneDAL can make tables from [int32, float32, float64], requiring
+        # a dtype check and conversion. This will query the array_namespace and
+        # convert y as necessary. This is important especially for regressors.
+        dtype = kwargs["dtype"]
+        if not isinstance(dtype, (tuple, list)):
+            dtype = (dtype,)
+
+        outx, outy = out if check_x else (None, out)
+        if outy.dtype not in dtype:
+            yp, _ = get_namespace(outy)
+            # use asarray rather than astype because of numpy support
+            outy = yp.asarray(outy, dtype=dtype[0])
+            out = (outx, outy) if check_x else outy
+
+    return out
+
+
+def _check_sample_weight(
+    sample_weight, X, dtype=None, copy=False, ensure_non_negative=False
+):
+
+    n_samples = _num_samples(X)
+    xp, _ = get_namespace(X)
+
+    if dtype is not None and dtype not in [xp.float32, xp.float64]:
+        dtype = xp.float64
+
+    if sample_weight is None:
+        if hasattr(X, "device"):
+            sample_weight = xp.ones(n_samples, dtype=dtype, device=X.device)
+        else:
+            sample_weight = xp.ones(n_samples, dtype=dtype)
+    elif isinstance(sample_weight, numbers.Number):
+        if hasattr(X, "device"):
+            sample_weight = xp.full(
+                n_samples, sample_weight, dtype=dtype, device=X.device
+            )
+        else:
+            sample_weight = xp.full(n_samples, sample_weight, dtype=dtype)
+    else:
+        if dtype is None:
+            dtype = [xp.float64, xp.float32]
+
+        params = {
+            "accept_sparse": False,
+            "ensure_2d": False,
+            "dtype": dtype,
+            "order": "C",
+            "copy": copy,
+            _finite_keyword: False,
+        }
+        if sklearn_check_version("1.1"):
+            params["input_name"] = "sample_weight"
+
+        sample_weight = check_array(sample_weight, **params)
+        assert_all_finite(sample_weight, input_name="sample_weight")
+
+        if sample_weight.ndim != 1:
+            raise ValueError("Sample weights must be 1D array or scalar")
+
+        if sample_weight.shape != (n_samples,):
+            raise ValueError(
+                "sample_weight.shape == {}, expected {}!".format(
+                    sample_weight.shape, (n_samples,)
+                )
+            )
+
+    if ensure_non_negative:
+        check_non_negative(sample_weight, "`sample_weight`")
+
+    return sample_weight
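
For illustration only (not part of the patch): a minimal sketch of how an estimator could call the new helpers, assuming plain numpy inputs. `SketchEstimator` is hypothetical; only `validate_data` and `_check_sample_weight` come from this patch.

    import numpy as np
    from sklearn.base import BaseEstimator

    from sklearnex.utils.validation import _check_sample_weight, validate_data


    class SketchEstimator(BaseEstimator):
        # hypothetical estimator used only to demonstrate the call pattern

        def fit(self, X, y, sample_weight=None):
            # sklearn performs everything but the finite check, which is then
            # run locally (and offloaded to oneDAL when supported)
            X, y = validate_data(self, X, y, dtype=[np.float64, np.float32])
            if sample_weight is not None:
                # validates shape and dtype against X, and re-checks
                # finiteness through assert_all_finite
                sample_weight = _check_sample_weight(sample_weight, X)
            return self


    SketchEstimator().fit(np.random.rand(1000, 3), np.random.rand(1000))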
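
Similarly, a usage sketch of the renamed public `assert_all_finite`, showing the 32768-element threshold used by `_sklearnex_assert_all_finite` (assumption: a oneDAL backend new enough for the accelerated path; either branch raises the same ValueError):

    import numpy as np

    from sklearnex.utils import assert_all_finite

    # 2**16 elements >= 32768, so a contiguous float array may dispatch to oneDAL
    X = np.ones((2**8, 2**8), dtype=np.float32)
    assert_all_finite(X)  # passes silently

    X[0, 0] = np.nan
    assert_all_finite(X, allow_nan=True)  # NaN is tolerated here
    try:
        assert_all_finite(X)  # raises: input contains NaN
    except ValueError as e:
        print(e)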