Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[enhancement] add oneDAL finiteness_checker implementation to onedal #2126

Merged
merged 79 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from 64 commits
Commits
Show all changes
79 commits
Select commit Hold shift + click to select a range
32fe269
add finiteness_checker pybind11 bindings
icfaust Oct 23, 2024
cdbf1b5
added finiteness checker
icfaust Oct 23, 2024
62674a2
Update finiteness_checker.cpp
icfaust Oct 23, 2024
c75c23b
Update finiteness_checker.cpp
icfaust Oct 23, 2024
6a20938
Update finiteness_checker.cpp
icfaust Oct 23, 2024
382d7a1
Update finiteness_checker.cpp
icfaust Oct 23, 2024
c8ffd9c
Update finiteness_checker.cpp
icfaust Oct 23, 2024
9aa13d5
Update finiteness_checker.cpp
icfaust Oct 23, 2024
84e15d5
Rename finiteness_checker.cpp to finiteness_checker.cpp
icfaust Oct 23, 2024
63073c6
Update finiteness_checker.cpp
icfaust Oct 24, 2024
d915da5
Merge branch 'intel:main' into dev/new_assert_all_fininte
icfaust Oct 28, 2024
3dddf2d
add next step
icfaust Oct 31, 2024
1e1213e
follow conventions
icfaust Oct 31, 2024
0531713
make xtable explicit
icfaust Oct 31, 2024
e831167
remove comment
icfaust Oct 31, 2024
d6eb1d0
Update validation.py
icfaust Oct 31, 2024
fb30d6e
Update __init__.py
icfaust Nov 1, 2024
63a18c2
Update validation.py
icfaust Nov 1, 2024
76c0856
Update __init__.py
icfaust Nov 1, 2024
7deb2bb
Update __init__.py
icfaust Nov 1, 2024
ed46b29
Update validation.py
icfaust Nov 1, 2024
67d6273
Update _data_conversion.py
icfaust Nov 1, 2024
054f0a1
Merge branch 'main' into dev/new_assert_all_fininte
icfaust Nov 1, 2024
8abead9
Update _data_conversion.py
icfaust Nov 1, 2024
47d0f8b
Update policy_common.cpp
icfaust Nov 1, 2024
e48c2bd
Update policy_common.cpp
icfaust Nov 1, 2024
c6751c4
Update _policy.py
icfaust Nov 1, 2024
f3e4a3a
Update policy_common.cpp
icfaust Nov 2, 2024
39cdb5f
Rename finiteness_checker.cpp to finiteness_checker.cpp
icfaust Nov 2, 2024
0f39613
Create finiteness_checker.py
icfaust Nov 2, 2024
b42cfe3
Update validation.py
icfaust Nov 2, 2024
0ed615e
Update __init__.py
icfaust Nov 2, 2024
f101aff
attempt at fixing circular imports again
icfaust Nov 2, 2024
24c0e94
fix isort
icfaust Nov 2, 2024
3f96166
remove __init__ changes
icfaust Nov 2, 2024
d985053
last move
icfaust Nov 2, 2024
90ec48b
Update policy_common.cpp
icfaust Nov 2, 2024
8c2c854
Update policy_common.cpp
icfaust Nov 2, 2024
6fa38d7
Update policy_common.cpp
icfaust Nov 2, 2024
9c1ca9c
Update policy_common.cpp
icfaust Nov 2, 2024
4b67dbd
Update validation.py
icfaust Nov 2, 2024
fa59a3c
add testing
icfaust Nov 2, 2024
3330b33
isort
icfaust Nov 2, 2024
4895940
attempt to fix module error
icfaust Nov 2, 2024
0c6dd5d
add fptype
icfaust Nov 2, 2024
e2182fa
fix typo
icfaust Nov 2, 2024
982ef2c
Update validation.py
icfaust Nov 2, 2024
2fb52a8
remove sua_ifcae from to_table
icfaust Nov 3, 2024
28dc267
isort and black
icfaust Nov 3, 2024
2f85fd4
Update test_memory_usage.py
icfaust Nov 3, 2024
8659248
format
icfaust Nov 3, 2024
3827d6f
Update _data_conversion.py
icfaust Nov 3, 2024
55fa7d2
Update _data_conversion.py
icfaust Nov 3, 2024
175cd78
Update test_validation.py
icfaust Nov 3, 2024
7016ad0
remove unnecessary code
icfaust Nov 3, 2024
1a01859
Merge branch 'main' into dev/new_assert_all_fininte
icfaust Nov 18, 2024
2fbcdd9
merge master
icfaust Nov 18, 2024
fb7375f
make reviewer changes
icfaust Nov 19, 2024
30816bf
make dtype check change
icfaust Nov 19, 2024
abb3b16
add sparse testing
icfaust Nov 19, 2024
97aef73
try again
icfaust Nov 19, 2024
6e29651
try again
icfaust Nov 19, 2024
59363a8
try again
icfaust Nov 19, 2024
61da628
Merge branch 'intel:main' into dev/new_assert_all_fininte
icfaust Nov 20, 2024
e3facab
Update onedal/utils/tests/test_validation.py
icfaust Nov 20, 2024
5bb54a5
Merge branch 'intel:main' into dev/new_assert_all_fininte
icfaust Nov 21, 2024
e8d8c71
formatting
icfaust Nov 21, 2024
1e09b11
Merge branch 'intel:main' into dev/new_assert_all_fininte
icfaust Nov 21, 2024
afc76b8
formatting again
icfaust Nov 21, 2024
edf0350
Merge branch 'dev/new_assert_all_fininte' of https://github.com/icfau…
icfaust Nov 21, 2024
4efad2c
add _check_sample_weight
icfaust Nov 22, 2024
63e2fa8
Revert "add _check_sample_weight"
icfaust Nov 22, 2024
48cafbc
Update test_validation.py
icfaust Nov 25, 2024
085f8a7
Update validation.py
icfaust Nov 25, 2024
cdb11f2
Merge branch 'intel:main' into dev/new_assert_all_fininte
icfaust Nov 25, 2024
b539d23
make changes
icfaust Nov 27, 2024
5549f99
Merge branch 'intel:main' into dev/new_assert_all_fininte
icfaust Nov 27, 2024
61ca3db
Merge branch 'intel:main' into dev/new_assert_all_fininte
icfaust Nov 27, 2024
63d9566
Update test_validation.py
icfaust Nov 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions onedal/dal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ namespace oneapi::dal::python {
#if defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240001
ONEDAL_PY_INIT_MODULE(logistic_regression);
#endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240001
#if defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240700
ONEDAL_PY_INIT_MODULE(finiteness_checker);
#endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240700
#endif // ONEDAL_DATA_PARALLEL_SPMD

#ifdef ONEDAL_DATA_PARALLEL_SPMD
Expand Down Expand Up @@ -138,6 +141,9 @@ namespace oneapi::dal::python {
#if defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240001
init_logistic_regression(m);
#endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240001
#if defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240700
init_finiteness_checker(m);
#endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240700
}
#endif // ONEDAL_DATA_PARALLEL_SPMD

Expand Down
4 changes: 4 additions & 0 deletions onedal/datatypes/table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ ONEDAL_PY_INIT_MODULE(table) {
const auto column_count = t.get_column_count();
return py::make_tuple(row_count, column_count);
});
table_obj.def_property_readonly("dtype", [](const table& t){
// returns a numpy dtype, even if source was not from numpy
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this required? Will this be consistent for different array inputs?

Copy link
Contributor Author

@icfaust icfaust Nov 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@samir-nasibli good question! This simplifies dtype checking for the backend handoff for two reasons: 1) what oneDAL cares about most is what the dtype of the tables to it are (i.e. the output of to_table). While it happens that the dtypes of inputs to to_table match the dtype out, its not guaranteed. Doing it this way removes an indirection. 2) It is by luck that dpnp/dpctl allow for dtype comparisons to numpy dtypes, which is not the case for array_api_strict inputs. By having a consistent numpy dtype to oneDAL tables in python, we will not have to worry about supporting all the different dtypes of all the various array inputs at the param dict creation stage. We can then always compare to numpy dtypes in the params creation using the input tables.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@icfaust Thank you for the clarification! I don't like this extension of table API which is wo for numpy and dpnp/dpctl.

I would suggest to solve the problem at python level https://github.com/intel/scikit-learn-intelex/blob/5bb54a520f5f3d1cf719e149df6526de943e393b/onedal/utils/validation.py#L452 without needing to extend table API

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@samir-nasibli I disagree, I think we should find someone to arbitrate on this.

Copy link
Contributor

@samir-nasibli samir-nasibli Nov 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also you can name it as numpy_dtype instead of dtype not to confuse user. But again I don't see a lot of profit having this attr for the numeric table. Let's have a broader discussion.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you clarify the situation that this would be helping? Like if an alternative array backend is used then the dtype of the input to to_table may not match the output of from_table?

Copy link
Contributor Author

@icfaust icfaust Nov 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The large problem coming to us soon is handling zero-copy of arrays to the backend via the array_api standard. Comparison to numpy-dtypes when an object has different dtypes (for example an array_api_strict float64 dtype is not equivalent to a numpy float64 dtype, and will yield a false). We will have to do one of two things:

adapt the often:
"float" if dtype==np.float64 else "double" which occurs when creating param dicts
or
always use the tables dtype, which has already been translated into a oneDAL, appropriate datatype.

By using numpy as the datatype, we are not having to re-invent the wheel with creating our own dtype standard, as pybind11 supports it natively in C++, its already all over the codebase in sklearnex/onedal.

Secondarily, I have talked about how actually our backend can ingest complex float and complex double types, but will treat them as float types or double types. See here:

#2172 (comment)

return py::dtype(convert_dal_to_npy_type(t.get_metadata().get_data_type(0)));
});

#ifdef ONEDAL_DATA_PARALLEL
define_sycl_usm_array_property(table_obj);
Expand Down
103 changes: 103 additions & 0 deletions onedal/utils/finiteness_checker.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*******************************************************************************
* Copyright 2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

// fix error with missing headers
#if defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20250200
#include "oneapi/dal/algo/finiteness_checker.hpp"
#else
#include "oneapi/dal/algo/finiteness_checker/compute.hpp"
#endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20250200

#include "onedal/common.hpp"
#include "onedal/version.hpp"

namespace py = pybind11;

namespace oneapi::dal::python {

template <typename Task, typename Ops>
struct method2t {
method2t(const Task& task, const Ops& ops) : ops(ops) {}

template <typename Float>
auto operator()(const py::dict& params) {
using namespace finiteness_checker;

const auto method = params["method"].cast<std::string>();

ONEDAL_PARAM_DISPATCH_VALUE(method, "dense", ops, Float, method::dense);
ONEDAL_PARAM_DISPATCH_VALUE(method, "by_default", ops, Float, method::by_default);
ONEDAL_PARAM_DISPATCH_THROW_INVALID_VALUE(method);
}

Ops ops;
};

struct params2desc {
template <typename Float, typename Method, typename Task>
auto operator()(const pybind11::dict& params) {
using namespace dal::finiteness_checker;

auto desc = descriptor<Float, Method, Task>();
desc.set_allow_NaN(params["allow_nan"].cast<bool>());
return desc;
}
};

template <typename Policy, typename Task>
void init_compute_ops(py::module_& m) {
m.def("compute",
[](const Policy& policy,
const py::dict& params,
const table& data) {
using namespace finiteness_checker;
using input_t = compute_input<Task>;

compute_ops ops(policy, input_t{ data }, params2desc{});
return fptype2t{ method2t{ Task{}, ops } }(params);
});
}

template <typename Task>
void init_compute_result(py::module_& m) {
using namespace finiteness_checker;
using result_t = compute_result<Task>;

py::class_<result_t>(m, "compute_result")
.def(py::init())
.DEF_ONEDAL_PY_PROPERTY(finite, result_t);
}

ONEDAL_PY_TYPE2STR(finiteness_checker::task::compute, "compute");

ONEDAL_PY_DECLARE_INSTANTIATOR(init_compute_ops);
ONEDAL_PY_DECLARE_INSTANTIATOR(init_compute_result);

ONEDAL_PY_INIT_MODULE(finiteness_checker) {
using namespace dal::detail;
using namespace finiteness_checker;
using namespace dal::finiteness_checker;

using task_list = types<task::compute>;
auto sub = m.def_submodule("finiteness_checker");

#ifndef ONEDAL_DATA_PARALLEL_SPMD
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

only instantiate if not data parallel spmd?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As of now there isn't an spmd implimentation in oneDAL. If you forsee this as a problem, we will have to go back to oneDAL and add it (we can make that available for next release if important). @ethanglaser let me know.

ONEDAL_PY_INSTANTIATE(init_compute_ops, sub, policy_list, task_list);
ONEDAL_PY_INSTANTIATE(init_compute_result, sub, task_list);
#endif
}

} // namespace oneapi::dal::python
149 changes: 149 additions & 0 deletions onedal/utils/tests/test_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# ==============================================================================
Vika-F marked this conversation as resolved.
Show resolved Hide resolved
# Copyright 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import time

import numpy as np
import numpy.random as rand
import pytest
import scipy.sparse as sp

from onedal.tests.utils._dataframes_support import (
_convert_to_dataframe,
get_dataframes_and_queues,
)
from onedal.utils.validation import _assert_all_finite, assert_all_finite
icfaust marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize(
"shape",
[
[16, 2048],
[
2**16 + 3,
icfaust marked this conversation as resolved.
Show resolved Hide resolved
],
[1000, 1000],
[
3,
],
],
)
@pytest.mark.parametrize("allow_nan", [False, True])
@pytest.mark.parametrize(
"dataframe, queue", get_dataframes_and_queues("numpy,dpnp,dpctl")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are there issues with pandas?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pandas inputs are not to be encountered by this function, check_array should always convert them to numpy inputs, we also do not support heterogeneous tables yet, which will be done at a later point.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This issue is handled in the follow up PR: #2177

)
def test_sum_infinite_actually_finite(dtype, shape, allow_nan, dataframe, queue):
X = np.array(shape, dtype=dtype)
X.fill(np.finfo(dtype).max)
X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
_assert_all_finite(X, allow_nan=allow_nan)


@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize(
"shape",
[
[16, 2048],
[
2**16 + 3,
icfaust marked this conversation as resolved.
Show resolved Hide resolved
],
[1000, 1000],
[
3,
],
],
)
@pytest.mark.parametrize("allow_nan", [False, True])
@pytest.mark.parametrize("check", ["inf", "NaN", None])
@pytest.mark.parametrize("seed", [0, int(time.time())])
@pytest.mark.parametrize(
"dataframe, queue", get_dataframes_and_queues("numpy,dpnp,dpctl")
)
def test_assert_finite_random_location(
dtype, shape, allow_nan, check, seed, dataframe, queue
):
rand.seed(seed)
X = rand.uniform(high=np.finfo(dtype).max, size=shape).astype(dtype)

if check:
loc = rand.randint(0, X.size - 1)
X.reshape((-1,))[loc] = float(check)

X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)

if check is None or (allow_nan and check == "NaN"):
_assert_all_finite(X, allow_nan=allow_nan)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are we checking that result of call is expected?

Copy link
Contributor Author

@icfaust icfaust Nov 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_assert_all_finite will raise a ValueError if a non-finite occurs, meaning we are using the native functionality of the function in the test to cause a fail (i.e. it will raise an error if there is a problem, in this case, no error should be expected).

else:
msg_err = "Input contains " + ("infinity" if allow_nan else "NaN, infinity") + "."
with pytest.raises(ValueError, match=msg_err):
_assert_all_finite(X, allow_nan=allow_nan)


@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize("allow_nan", [False, True])
@pytest.mark.parametrize("check", ["inf", "NaN", None])
@pytest.mark.parametrize("seed", [0, int(time.time())])
@pytest.mark.parametrize(
"dataframe, queue", get_dataframes_and_queues("numpy,dpnp,dpctl")
)
def test_assert_finite_random_shape_and_location(
dtype, allow_nan, check, seed, dataframe, queue
):
lb, ub = 2, 1048576 # lb is a patching condition, ub 2^20
rand.seed(seed)
X = rand.uniform(high=np.finfo(dtype).max, size=rand.randint(lb, ub)).astype(dtype)

if check:
loc = rand.randint(0, X.size - 1)
X[loc] = float(check)

X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)

if check is None or (allow_nan and check == "NaN"):
_assert_all_finite(X, allow_nan=allow_nan)
else:
msg_err = "Input contains " + ("infinity" if allow_nan else "NaN, infinity") + "."
with pytest.raises(ValueError, match=msg_err):
_assert_all_finite(X, allow_nan=allow_nan)


@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize("allow_nan", [False, True])
@pytest.mark.parametrize("check", ["inf", "NaN", None])
@pytest.mark.parametrize("seed", [0, int(time.time())])
def test_assert_finite_sparse(dtype, allow_nan, check, seed):
lb, ub = 2, 2056
rand.seed(seed)
X = sp.random(
rand.randint(lb, ub),
rand.randint(lb, ub),
format="csr",
dtype=dtype,
random_state=rand.default_rng(seed),
)

if check:
locx = rand.randint(0, X.shape[0] - 1)
locy = rand.randint(0, X.shape[1] - 1)
X[locx, locy] = float(check)

if check is None or (allow_nan and check == "NaN"):
assert_all_finite(X, allow_nan=allow_nan)
else:
msg_err = "Input contains " + ("infinity" if allow_nan else "NaN, infinity") + "."
with pytest.raises(ValueError, match=msg_err):
assert_all_finite(X, allow_nan=allow_nan)
43 changes: 38 additions & 5 deletions onedal/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.validation import check_array

from daal4py.sklearn.utils.validation import _assert_all_finite
from daal4py.sklearn.utils.validation import (
_assert_all_finite as _daal4py_assert_all_finite,
)
from onedal import _backend
from onedal.common._policy import _get_policy
from onedal.datatypes import _convert_to_supported, to_table


class DataConversionWarning(UserWarning):
Expand Down Expand Up @@ -135,10 +140,10 @@ def _check_array(
if force_all_finite:
if sp.issparse(array):
if hasattr(array, "data"):
_assert_all_finite(array.data)
_daal4py_assert_all_finite(array.data)
force_all_finite = False
else:
_assert_all_finite(array)
_daal4py_assert_all_finite(array)
force_all_finite = False
array = check_array(
array=array,
Expand Down Expand Up @@ -200,7 +205,7 @@ def _check_X_y(
if y_numeric and y.dtype.kind == "O":
y = y.astype(np.float64)
if force_all_finite:
_assert_all_finite(y)
_daal4py_assert_all_finite(y)

lengths = [X.shape[0], y.shape[0]]
uniques = np.unique(lengths)
Expand Down Expand Up @@ -285,7 +290,7 @@ def _type_of_target(y):
# check float and contains non-integer float values
if y.dtype.kind == "f" and np.any(y != y.astype(int)):
# [.1, .2, 3] or [[.1, .2, 3]] or [[1., .2]] and not [1., 2., 3.]
_assert_all_finite(y)
_daal4py_assert_all_finite(y)
return "continuous" + suffix

if (len(np.unique(y)) > 2) or (y.ndim >= 2 and len(y[0]) > 1):
Expand Down Expand Up @@ -438,3 +443,31 @@ def _is_csr(x):
return isinstance(x, sp.csr_matrix) or (
hasattr(sp, "csr_array") and isinstance(x, sp.csr_array)
)


def _assert_all_finite(X, allow_nan=False, input_name=""):
policy = _get_policy(None, X)
X_t = to_table(_convert_to_supported(policy, X))
params = {
"fptype": "float" if X_t.dtype == np.float32 else "double",
"method": "dense",
"allow_nan": allow_nan,
}
if not _backend.finiteness_checker.compute.compute(policy, params, X_t).finite:
type_err = "infinity" if allow_nan else "NaN, infinity"
padded_input_name = input_name + " " if input_name else ""
msg_err = f"Input {padded_input_name}contains {type_err}."
icfaust marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(msg_err)


def assert_all_finite(
X,
*,
allow_nan=False,
input_name="",
):
_assert_all_finite(
X.data if sp.issparse(X) else X,
allow_nan=allow_nan,
input_name=input_name,
)
Comment on lines +454 to +464
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this main API for the feature? Is it used in the tests?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added test which uses the main function name, also testing sparse arrays as well.