From cb31797ad3189d30e32bbac2af9f369373488cdc Mon Sep 17 00:00:00 2001
From: vlad-perevezentsev
Date: Fri, 9 Aug 2024 18:07:49 +0200
Subject: [PATCH] Implement `dpnp.nan_to_num()` (#1966)

* Implement dpnp.nan_to_num()

* Update cupy tests for nan_to_num()

* Add dpnp tests

* Skip test_nan_to_num_scalar_nan

* Applied review comments

* Add more tests for nan_to_num()

* Improve performance using out empty_like array

* Add checks for nan, posinf, neginf args

* Add type check for nan, posinf and neginf

* Update tests

* Add support boolean type

* Return a copy for non-inexact input when copy=True
---
Review note: `nan_to_num()` now returns a new array for non-inexact input
when `copy=True` (previously the input itself was returned), and a
regression test was added to `TestNanToNum`. The `index` blob ids of
dpnp/dpnp_iface_mathematical.py and tests/test_mathematical.py were not
regenerated after that change.

 dpnp/dpnp_iface_mathematical.py               | 148 ++++++++++++++++++
 tests/skipped_tests.tbl                       |  16 --
 tests/skipped_tests_gpu.tbl                   |  16 --
 tests/test_mathematical.py                    |  70 +++++++++
 tests/test_sycl_queue.py                      |  14 ++
 tests/test_usm_type.py                        |  10 ++
 .../third_party/cupy/math_tests/test_misc.py  |  14 +-
 7 files changed, 250 insertions(+), 38 deletions(-)

diff --git a/dpnp/dpnp_iface_mathematical.py b/dpnp/dpnp_iface_mathematical.py
index 995bf7d36ff..44eedbb3918 100644
--- a/dpnp/dpnp_iface_mathematical.py
+++ b/dpnp/dpnp_iface_mathematical.py
@@ -110,6 +110,7 @@
     "mod",
     "modf",
     "multiply",
+    "nan_to_num",
     "negative",
     "nextafter",
     "positive",
@@ -130,6 +131,13 @@
 ]
 
 
+def _get_max_min(dtype):
+    """Get the maximum and minimum representable values for an inexact dtype."""
+
+    f = dpnp.finfo(dtype)
+    return f.max, f.min
+
+
 def _get_reduction_res_dt(a, dtype, _out):
     """Get a data type used by dpctl for result array in reduction function."""
 
@@ -2353,6 +2361,146 @@ def modf(x1, **kwargs):
     )
 
 
+def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None):
+    """
+    Replace ``NaN`` with zero and infinity with large finite numbers (default
+    behaviour) or with the numbers defined by the user using the `nan`,
+    `posinf` and/or `neginf` keywords.
+
+    If `x` is inexact, ``NaN`` is replaced by zero or by the user defined value
+    in `nan` keyword, infinity is replaced by the largest finite floating point
+    values representable by ``x.dtype`` or by the user defined value in
+    `posinf` keyword and -infinity is replaced by the most negative finite
+    floating point values representable by ``x.dtype`` or by the user defined
+    value in `neginf` keyword.
+
+    For complex dtypes, the above is applied to each of the real and
+    imaginary components of `x` separately.
+
+    If `x` is not inexact, then no replacements are made.
+
+    For full documentation refer to :obj:`numpy.nan_to_num`.
+
+    Parameters
+    ----------
+    x : {dpnp.ndarray, usm_ndarray}
+        Input data.
+    copy : bool, optional
+        Whether to create a copy of `x` (``True``) or to replace values
+        in-place (``False``). The in-place operation only occurs if casting to
+        an array does not require a copy.
+    nan : {int, float, bool}, optional
+        Value to be used to fill ``NaN`` values.
+        Default: ``0.0``.
+    posinf : {int, float, bool, None}, optional
+        Value to be used to fill positive infinity values. If no value is
+        passed then positive infinity values will be replaced with a very
+        large number.
+        Default: ``None``.
+    neginf : {int, float, bool, None}, optional
+        Value to be used to fill negative infinity values. If no value is
+        passed then negative infinity values will be replaced with a very
+        small (or negative) number.
+        Default: ``None``.
+
+    Returns
+    -------
+    out : dpnp.ndarray
+        `x`, with the non-finite values replaced. If `copy` is ``False``, this
+        may be `x` itself.
+
+    See Also
+    --------
+    :obj:`dpnp.isinf` : Shows which elements are positive or negative infinity.
+    :obj:`dpnp.isneginf` : Shows which elements are negative infinity.
+    :obj:`dpnp.isposinf` : Shows which elements are positive infinity.
+    :obj:`dpnp.isnan` : Shows which elements are Not a Number (NaN).
+    :obj:`dpnp.isfinite` : Shows which elements are finite
+                           (not NaN, not infinity)
+
+    Examples
+    --------
+    >>> import dpnp as np
+    >>> np.nan_to_num(np.array(np.inf))
+    array(1.79769313e+308)
+    >>> np.nan_to_num(np.array(-np.inf))
+    array(-1.79769313e+308)
+    >>> np.nan_to_num(np.array(np.nan))
+    array(0.)
+    >>> x = np.array([np.inf, -np.inf, np.nan, -128, 128])
+    >>> np.nan_to_num(x)
+    array([ 1.79769313e+308, -1.79769313e+308,  0.00000000e+000,
+           -1.28000000e+002,  1.28000000e+002])
+    >>> np.nan_to_num(x, nan=-9999, posinf=33333333, neginf=33333333)
+    array([ 3.3333333e+07,  3.3333333e+07, -9.9990000e+03, -1.2800000e+02,
+            1.2800000e+02])
+    >>> y = np.array([complex(np.inf, np.nan), np.nan, complex(np.nan, np.inf)])
+    >>> np.nan_to_num(y)
+    array([1.79769313e+308 +0.00000000e+000j, # may vary
+           0.00000000e+000 +0.00000000e+000j,
+           0.00000000e+000 +1.79769313e+308j])
+    >>> np.nan_to_num(y, nan=111111, posinf=222222)
+    array([222222.+111111.j, 111111.     +0.j, 111111.+222222.j])
+
+    """
+
+    dpnp.check_supported_arrays_type(x)
+
+    # Python boolean is a subtype of an integer
+    # so additional check for bool is not needed.
+    if not isinstance(nan, (int, float)):
+        raise TypeError(
+            "nan must be a scalar of an integer, float, bool, "
+            f"but got {type(nan)}"
+        )
+
+    x_type = x.dtype.type
+
+    if not issubclass(x_type, dpnp.inexact):
+        # There is nothing to replace, but `copy=True` must still produce
+        # a new array instead of handing back an alias of the input.
+        return dpnp.copy(x) if copy else x
+
+    # Allocate the result only past the early return above, so that no
+    # buffer is created and then discarded for non-inexact input.
+    out = dpnp.empty_like(x) if copy else x
+
+    parts = (
+        (x.real, x.imag) if issubclass(x_type, dpnp.complexfloating) else (x,)
+    )
+    parts_out = (
+        (out.real, out.imag)
+        if issubclass(x_type, dpnp.complexfloating)
+        else (out,)
+    )
+    max_f, min_f = _get_max_min(x.real.dtype)
+    if posinf is not None:
+        if not isinstance(posinf, (int, float)):
+            raise TypeError(
+                "posinf must be a scalar of an integer, float, bool, "
+                f"or be None, but got {type(posinf)}"
+            )
+        max_f = posinf
+    if neginf is not None:
+        if not isinstance(neginf, (int, float)):
+            raise TypeError(
+                "neginf must be a scalar of an integer, float, bool, "
+                f"or be None, but got {type(neginf)}"
+            )
+        min_f = neginf
+
+    for part, part_out in zip(parts, parts_out):
+        nan_mask = dpnp.isnan(part)
+        posinf_mask = dpnp.isposinf(part)
+        neginf_mask = dpnp.isneginf(part)
+
+        part = dpnp.where(nan_mask, nan, part, out=part_out)
+        part = dpnp.where(posinf_mask, max_f, part, out=part_out)
+        part = dpnp.where(neginf_mask, min_f, part, out=part_out)
+
+    return out
+
+
 _NEGATIVE_DOCSTRING = """
 Computes the numerical negative for each element `x_i` of input array `x`.
 
diff --git a/tests/skipped_tests.tbl b/tests/skipped_tests.tbl
index 72afb525755..5007ef2f0bc 100644
--- a/tests/skipped_tests.tbl
+++ b/tests/skipped_tests.tbl
@@ -206,22 +206,6 @@ tests/third_party/cupy/manipulation_tests/test_dims.py::TestInvalidBroadcast_par
 tests/third_party/cupy/manipulation_tests/test_dims.py::TestInvalidBroadcast_param_2_{shapes=[(3, 2), (3, 4)]}::test_invalid_broadcast
 tests/third_party/cupy/manipulation_tests/test_dims.py::TestInvalidBroadcast_param_3_{shapes=[(0,), (2,)]}::test_invalid_broadcast
 
-tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num
-tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_negative
-tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_for_old_numpy
-tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_negative_for_old_numpy
-tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_inf
-tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_nan
-tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_inf_nan
-tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_nan_arg
-tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_inf_arg
-tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_broadcast[nan]
-tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_broadcast[posinf]
-tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_broadcast[neginf]
-
-tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_scalar_nan
-tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_copy
-tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_inplace
 tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_real_if_close_real_dtypes
 tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_real_if_close_with_tol_real_dtypes
 tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_real_if_close_true
diff --git a/tests/skipped_tests_gpu.tbl b/tests/skipped_tests_gpu.tbl
index b1884abaabe..60665393efe 100644
--- a/tests/skipped_tests_gpu.tbl
+++ b/tests/skipped_tests_gpu.tbl
@@ -260,22 +260,6 @@ tests/third_party/cupy/manipulation_tests/test_dims.py::TestInvalidBroadcast_par
 tests/third_party/cupy/manipulation_tests/test_dims.py::TestInvalidBroadcast_param_2_{shapes=[(3, 2), (3, 4)]}::test_invalid_broadcast
 tests/third_party/cupy/manipulation_tests/test_dims.py::TestInvalidBroadcast_param_3_{shapes=[(0,), (2,)]}::test_invalid_broadcast
 
-tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num
-tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_negative
-tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_for_old_numpy
-tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_negative_for_old_numpy
-tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_inf
-tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_nan
-tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_inf_nan
-tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_nan_arg
-tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_inf_arg
-tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_broadcast[nan]
-tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_broadcast[posinf]
-tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_broadcast[neginf]
-
-tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_scalar_nan
-tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_copy
-tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_inplace
 tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_real_if_close_real_dtypes
 tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_real_if_close_with_tol_real_dtypes
 tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_real_if_close_true
diff --git a/tests/test_mathematical.py b/tests/test_mathematical.py
index 8b91e1876d3..6ae071ab394 100644
--- a/tests/test_mathematical.py
+++ b/tests/test_mathematical.py
@@ -1116,6 +1116,76 @@ def test_subtract(self, dtype, lhs, rhs):
         self._test_mathematical("subtract", dtype, lhs, rhs, check_type=False)
 
 
+class TestNanToNum:
+    @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
+    @pytest.mark.parametrize("shape", [(3,), (2, 3), (3, 2, 2)])
+    def test_nan_to_num(self, dtype, shape):
+        a = numpy.random.randn(*shape).astype(dtype)
+        if not dpnp.issubdtype(dtype, dpnp.integer):
+            a.flat[1] = numpy.nan
+        a_dp = dpnp.array(a)
+
+        result = dpnp.nan_to_num(a_dp)
+        expected = numpy.nan_to_num(a)
+        assert_allclose(result, expected)
+
+    @pytest.mark.parametrize(
+        "data", [[], [numpy.nan], [numpy.inf], [-numpy.inf]]
+    )
+    @pytest.mark.parametrize("dtype", get_float_complex_dtypes())
+    def test_empty_and_single_value_arrays(self, data, dtype):
+        a = numpy.array(data, dtype)
+        ia = dpnp.array(a)
+
+        result = dpnp.nan_to_num(ia)
+        expected = numpy.nan_to_num(a)
+        assert_allclose(result, expected)
+
+    def test_boolean_array(self):
+        a = numpy.array([True, False, numpy.nan], dtype=bool)
+        ia = dpnp.array(a)
+
+        result = dpnp.nan_to_num(ia)
+        expected = numpy.nan_to_num(a)
+        assert_allclose(result, expected)
+
+    @pytest.mark.parametrize("copy", [True, False])
+    def test_copy_non_inexact(self, copy):
+        # nothing to replace in an integer array, but `copy` must be honored
+        a = numpy.array([1, 2, 3])
+        ia = dpnp.array(a)
+
+        result = dpnp.nan_to_num(ia, copy=copy)
+        assert_allclose(result, a)
+        assert copy == (result is not ia)
+
+    def test_errors(self):
+        ia = dpnp.array([0, 1, dpnp.nan, dpnp.inf, -dpnp.inf])
+
+        # unsupported type `a`
+        a_np = dpnp.asnumpy(ia)
+        assert_raises(TypeError, dpnp.nan_to_num, a_np)
+
+        # unsupported type `nan`
+        i_nan = dpnp.array(1)
+        assert_raises(TypeError, dpnp.nan_to_num, ia, nan=i_nan)
+
+        # unsupported type `posinf`
+        i_posinf = dpnp.array(1)
+        assert_raises(TypeError, dpnp.nan_to_num, ia, posinf=i_posinf)
+
+        # unsupported type `neginf`
+        i_neginf = dpnp.array(1)
+        assert_raises(TypeError, dpnp.nan_to_num, ia, neginf=i_neginf)
+
+    @pytest.mark.parametrize("kwarg", ["nan", "posinf", "neginf"])
+    @pytest.mark.parametrize("value", [1 - 0j, [1, 2], (1,)])
+    def test_errors_diff_types(self, kwarg, value):
+        ia = dpnp.array([0, 1, dpnp.nan, dpnp.inf, -dpnp.inf])
+        with pytest.raises(TypeError):
+            dpnp.nan_to_num(ia, **{kwarg: value})
+
+
 class TestNextafter:
     @pytest.mark.parametrize("dt", get_float_dtypes())
     @pytest.mark.parametrize(
diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py
index 322bfcf3a78..848e7adcb95 100644
--- a/tests/test_sycl_queue.py
+++ b/tests/test_sycl_queue.py
@@ -2336,3 +2336,17 @@ def test_astype(device_x, device_y):
     sycl_queue = dpctl.SyclQueue(device_y)
     y = dpnp.astype(x, dtype="f4", device=sycl_queue)
     assert_sycl_queue_equal(y.sycl_queue, sycl_queue)
+
+
+@pytest.mark.parametrize("copy", [True, False], ids=["True", "False"])
+@pytest.mark.parametrize(
+    "device",
+    valid_devices,
+    ids=[device.filter_string for device in valid_devices],
+)
+def test_nan_to_num(copy, device):
+    a = dpnp.array([-dpnp.nan, -1, 0, 1, dpnp.nan], device=device)
+    result = dpnp.nan_to_num(a, copy=copy)
+
+    assert_sycl_queue_equal(result.sycl_queue, a.sycl_queue)
+    assert copy == (result is not a)
diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py
index ab1956eb428..b1fa9ebe1c7 100644
--- a/tests/test_usm_type.py
+++ b/tests/test_usm_type.py
@@ -1361,3 +1361,13 @@ def test_histogram_bin_edges(usm_type_v, usm_type_w):
     assert v.usm_type == usm_type_v
     assert w.usm_type == usm_type_w
     assert edges.usm_type == du.get_coerced_usm_type([usm_type_v, usm_type_w])
+
+
+@pytest.mark.parametrize("copy", [True, False], ids=["True", "False"])
+@pytest.mark.parametrize("usm_type_a", list_of_usm_types, ids=list_of_usm_types)
+def test_nan_to_num(copy, usm_type_a):
+    a = dp.array([-dp.nan, -1, 0, 1, dp.nan], usm_type=usm_type_a)
+    result = dp.nan_to_num(a, copy=copy)
+
+    assert result.usm_type == usm_type_a
+    assert copy == (result is not a)
diff --git a/tests/third_party/cupy/math_tests/test_misc.py b/tests/third_party/cupy/math_tests/test_misc.py
index 62717803aca..e3251d84125 100644
--- a/tests/third_party/cupy/math_tests/test_misc.py
+++ b/tests/third_party/cupy/math_tests/test_misc.py
@@ -245,6 +245,7 @@ def test_nan_to_num_inf(self):
     def test_nan_to_num_nan(self):
         self.check_unary_nan("nan_to_num")
 
+    @pytest.mark.skip(reason="Scalar input is not supported")
     @testing.numpy_cupy_allclose(atol=1e-5)
     def test_nan_to_num_scalar_nan(self, xp):
         return xp.nan_to_num(xp.nan)
@@ -260,26 +261,27 @@ def test_nan_to_num_inf_arg(self):
 
     @testing.numpy_cupy_array_equal()
     def test_nan_to_num_copy(self, xp):
-        x = xp.asarray([0, 1, xp.nan, 4], dtype=xp.float64)
+        x = xp.asarray([0, 1, xp.nan, 4], dtype=cupy.default_float_type())
         y = xp.nan_to_num(x, copy=True)
         assert x is not y
         return y
 
     @testing.numpy_cupy_array_equal()
     def test_nan_to_num_inplace(self, xp):
-        x = xp.asarray([0, 1, xp.nan, 4], dtype=xp.float64)
+        x = xp.asarray([0, 1, xp.nan, 4], dtype=cupy.default_float_type())
         y = xp.nan_to_num(x, copy=False)
         assert x is y
         return y
 
+    @pytest.mark.skip(reason="nan, posinf, neginf as array are not supported")
     @pytest.mark.parametrize("kwarg", ["nan", "posinf", "neginf"])
     def test_nan_to_num_broadcast(self, kwarg):
         for xp in (numpy, cupy):
-            x = xp.asarray([0, 1, xp.nan, 4], dtype=xp.float64)
-            y = xp.zeros((2, 4), dtype=xp.float64)
-            with pytest.raises(ValueError):
+            x = xp.asarray([0, 1, xp.nan, 4], dtype=cupy.default_float_type())
+            y = xp.zeros((2, 4), dtype=cupy.default_float_type())
+            with pytest.raises(TypeError):
                 xp.nan_to_num(x, **{kwarg: y})
-            with pytest.raises(ValueError):
+            with pytest.raises(TypeError):
                 xp.nan_to_num(0.0, **{kwarg: y})
 
     @testing.for_all_dtypes(no_bool=True, no_complex=True)