diff --git a/dpnp/backend/extensions/vm/hypot.hpp b/dpnp/backend/extensions/vm/hypot.hpp new file mode 100644 index 00000000000..6131d33b7f4 --- /dev/null +++ b/dpnp/backend/extensions/vm/hypot.hpp @@ -0,0 +1,81 @@ +//***************************************************************************** +// Copyright (c) 2023, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** + +#pragma once + +#include + +#include "common.hpp" +#include "types_matrix.hpp" + +namespace dpnp +{ +namespace backend +{ +namespace ext +{ +namespace vm +{ +template +sycl::event hypot_contig_impl(sycl::queue exec_q, + const std::int64_t n, + const char *in_a, + const char *in_b, + char *out_y, + const std::vector &depends) +{ + type_utils::validate_type_for_device(exec_q); + + const T *a = reinterpret_cast(in_a); + const T *b = reinterpret_cast(in_b); + T *y = reinterpret_cast(out_y); + + return mkl_vm::hypot(exec_q, + n, // number of elements to be calculated + a, // pointer `a` containing 1st input vector of size n + b, // pointer `b` containing 2nd input vector of size n + y, // pointer `y` to the output vector of size n + depends); +} + +template +struct HypotContigFactory +{ + fnT get() + { + if constexpr (std::is_same_v< + typename types::HypotOutputType::value_type, void>) + { + return nullptr; + } + else { + return hypot_contig_impl; + } + } +}; +} // namespace vm +} // namespace ext +} // namespace backend +} // namespace dpnp diff --git a/dpnp/backend/extensions/vm/types_matrix.hpp b/dpnp/backend/extensions/vm/types_matrix.hpp index bbf74ed8c86..3d91c95bfbb 100644 --- a/dpnp/backend/extensions/vm/types_matrix.hpp +++ b/dpnp/backend/extensions/vm/types_matrix.hpp @@ -291,6 +291,21 @@ struct FloorOutputType dpctl_td_ns::DefaultResultEntry>::result_type; }; +/** + * @brief A factory to define pairs of supported types for which + * MKL VM library provides support in oneapi::mkl::vm::hypot function. + * + * @tparam T Type of input vectors `a` and `b` and of result vector `y`. 
+ */ +template +struct HypotOutputType +{ + using value_type = typename std::disjunction< + dpctl_td_ns::BinaryTypeMapResultEntry, + dpctl_td_ns::BinaryTypeMapResultEntry, + dpctl_td_ns::DefaultResultEntry>::result_type; +}; + /** * @brief A factory to define pairs of supported types for which * MKL VM library provides support in oneapi::mkl::vm::ln function. diff --git a/dpnp/backend/extensions/vm/vm_py.cpp b/dpnp/backend/extensions/vm/vm_py.cpp index e4d89ff6e84..1cf4fd7d854 100644 --- a/dpnp/backend/extensions/vm/vm_py.cpp +++ b/dpnp/backend/extensions/vm/vm_py.cpp @@ -45,6 +45,7 @@ #include "cosh.hpp" #include "div.hpp" #include "floor.hpp" +#include "hypot.hpp" #include "ln.hpp" #include "mul.hpp" #include "pow.hpp" @@ -74,11 +75,12 @@ static unary_impl_fn_ptr_t atan_dispatch_vector[dpctl_td_ns::num_types]; static binary_impl_fn_ptr_t atan2_dispatch_vector[dpctl_td_ns::num_types]; static unary_impl_fn_ptr_t atanh_dispatch_vector[dpctl_td_ns::num_types]; static unary_impl_fn_ptr_t ceil_dispatch_vector[dpctl_td_ns::num_types]; +static unary_impl_fn_ptr_t conj_dispatch_vector[dpctl_td_ns::num_types]; static unary_impl_fn_ptr_t cos_dispatch_vector[dpctl_td_ns::num_types]; static unary_impl_fn_ptr_t cosh_dispatch_vector[dpctl_td_ns::num_types]; static binary_impl_fn_ptr_t div_dispatch_vector[dpctl_td_ns::num_types]; static unary_impl_fn_ptr_t floor_dispatch_vector[dpctl_td_ns::num_types]; -static unary_impl_fn_ptr_t conj_dispatch_vector[dpctl_td_ns::num_types]; +static binary_impl_fn_ptr_t hypot_dispatch_vector[dpctl_td_ns::num_types]; static unary_impl_fn_ptr_t ln_dispatch_vector[dpctl_td_ns::num_types]; static binary_impl_fn_ptr_t mul_dispatch_vector[dpctl_td_ns::num_types]; static binary_impl_fn_ptr_t pow_dispatch_vector[dpctl_td_ns::num_types]; @@ -494,6 +496,35 @@ PYBIND11_MODULE(_vm_impl, m) py::arg("sycl_queue"), py::arg("src"), py::arg("dst")); } + // BinaryUfunc: ==== Hypot(x1, x2) ==== + { + vm_ext::init_ufunc_dispatch_vector( + hypot_dispatch_vector); + 
+ auto hypot_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2, + arrayT dst, const event_vecT &depends = {}) { + return vm_ext::binary_ufunc(exec_q, src1, src2, dst, depends, + hypot_dispatch_vector); + }; + m.def("_hypot", hypot_pyapi, + "Call `hypot` function from OneMKL VM library to compute element " + "by element hypotenuse of `x`", + py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"), + py::arg("dst"), py::arg("depends") = py::list()); + + auto hypot_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src1, + arrayT src2, arrayT dst) { + return vm_ext::need_to_call_binary_ufunc(exec_q, src1, src2, dst, + hypot_dispatch_vector); + }; + m.def("_mkl_hypot_to_call", hypot_need_to_call_pyapi, + "Check input arguments to answer if `hypot` function from " + "OneMKL VM library can be used", + py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"), + py::arg("dst")); + } + // UnaryUfunc: ==== Ln(x) ==== { vm_ext::init_ufunc_dispatch_vector, func_type_map_t::find_type>}), ...); - ((fmap[DPNPFuncName::DPNP_FN_HYPOT_EXT][FT1][FTs] = - {get_floating_res_type(), - (void *)dpnp_hypot_c_ext< - func_type_map_t::find_type()>, - func_type_map_t::find_type, - func_type_map_t::find_type>, - get_floating_res_type(), - (void *)dpnp_hypot_c_ext< - func_type_map_t::find_type< - get_floating_res_type()>, - func_type_map_t::find_type, - func_type_map_t::find_type>}), - ...); ((fmap[DPNPFuncName::DPNP_FN_MAXIMUM_EXT][FT1][FTs] = {get_floating_res_type(), (void *)dpnp_maximum_c_ext< diff --git a/dpnp/dpnp_algo/dpnp_algo.pxd b/dpnp/dpnp_algo/dpnp_algo.pxd index 943fae54508..9513f9e085c 100644 --- a/dpnp/dpnp_algo/dpnp_algo.pxd +++ b/dpnp/dpnp_algo/dpnp_algo.pxd @@ -108,8 +108,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na DPNP_FN_FMOD_EXT DPNP_FN_FULL DPNP_FN_FULL_LIKE - DPNP_FN_HYPOT - DPNP_FN_HYPOT_EXT DPNP_FN_IDENTITY DPNP_FN_IDENTITY_EXT DPNP_FN_INV @@ -384,8 +382,6 @@ cpdef dpnp_descriptor dpnp_copy(dpnp_descriptor x1) """ 
Mathematical functions """ -cpdef dpnp_descriptor dpnp_hypot(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*, - dpnp_descriptor out=*, object where=*) cpdef dpnp_descriptor dpnp_fmax(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*, dpnp_descriptor out=*, object where=*) cpdef dpnp_descriptor dpnp_fmin(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*, diff --git a/dpnp/dpnp_algo/dpnp_algo_mathematical.pxi b/dpnp/dpnp_algo/dpnp_algo_mathematical.pxi index f903df9560c..da5f2ec1040 100644 --- a/dpnp/dpnp_algo/dpnp_algo_mathematical.pxi +++ b/dpnp/dpnp_algo/dpnp_algo_mathematical.pxi @@ -46,7 +46,6 @@ __all__ += [ "dpnp_fabs", "dpnp_fmod", "dpnp_gradient", - 'dpnp_hypot', "dpnp_fmax", "dpnp_fmin", "dpnp_modf", @@ -273,14 +272,6 @@ cpdef utils.dpnp_descriptor dpnp_gradient(utils.dpnp_descriptor y1, int dx=1): return result -cpdef utils.dpnp_descriptor dpnp_hypot(utils.dpnp_descriptor x1_obj, - utils.dpnp_descriptor x2_obj, - object dtype=None, - utils.dpnp_descriptor out=None, - object where=True): - return call_fptr_2in_1out_strides(DPNP_FN_HYPOT_EXT, x1_obj, x2_obj, dtype, out, where) - - cpdef utils.dpnp_descriptor dpnp_fmax(utils.dpnp_descriptor x1_obj, utils.dpnp_descriptor x2_obj, object dtype=None, diff --git a/dpnp/dpnp_algo/dpnp_elementwise_common.py b/dpnp/dpnp_algo/dpnp_elementwise_common.py index 5e78b2eac6b..c9bb15433af 100644 --- a/dpnp/dpnp_algo/dpnp_elementwise_common.py +++ b/dpnp/dpnp_algo/dpnp_elementwise_common.py @@ -63,6 +63,7 @@ "dpnp_floor_divide", "dpnp_greater", "dpnp_greater_equal", + "dpnp_hypot", "dpnp_imag", "dpnp_invert", "dpnp_isfinite", @@ -1264,6 +1265,66 @@ def dpnp_greater_equal(x1, x2, out=None, order="K"): return dpnp_array._create_from_usm_ndarray(res_usm) +_hypot_docstring_ = """ +hypot(x1, x2, out=None, order="K") +Calculates the hypotenuse for a right triangle with "legs" `x1_i` and `x2_i` of +input arrays `x1` and `x2`. 
+Args: + x1 (dpnp.ndarray): + First input array, expected to have a real-valued data type. + x2 (dpnp.ndarray): + Second input array, also expected to have a real-valued data type. + out ({None, dpnp.ndarray}, optional): + Output array to populate. + The array must have the correct shape and the expected data type. + order ("C","F","A","K", None, optional): + Memory layout of the newly allocated output array, if parameter `out` is `None`. + Default: "K". +Returns: + dpnp.ndarray: + An array containing the element-wise hypotenuse. The data type + of the returned array is determined by the Type Promotion Rules. +""" + + +def _call_hypot(src1, src2, dst, sycl_queue, depends=None): + """A callback to register in BinaryElementwiseFunc class of dpctl.tensor""" + + if depends is None: + depends = [] + + if vmi._mkl_hypot_to_call(sycl_queue, src1, src2, dst): + # call pybind11 extension for hypot() function from OneMKL VM + return vmi._hypot(sycl_queue, src1, src2, dst, depends) + return ti._hypot(src1, src2, dst, sycl_queue, depends) + + +hypot_func = BinaryElementwiseFunc( + "hypot", + ti._hypot_result_type, + _call_hypot, + _hypot_docstring_, +) + + +def dpnp_hypot(x1, x2, out=None, order="K"): + """ + Invokes hypot() function from pybind11 extension of OneMKL VM if possible. + + Otherwise fully relies on dpctl.tensor implementation for hypot() function. 
+ """ + + # dpctl.tensor only works with usm_ndarray or scalar + x1_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x1) + x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2) + out_usm = None if out is None else dpnp.get_usm_ndarray(out) + + res_usm = hypot_func( + x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order + ) + return dpnp_array._create_from_usm_ndarray(res_usm) + + _imag_docstring = """ imag(x, out=None, order="K") diff --git a/dpnp/dpnp_iface_trigonometric.py b/dpnp/dpnp_iface_trigonometric.py index 3ed6de42e99..a68c9718b89 100644 --- a/dpnp/dpnp_iface_trigonometric.py +++ b/dpnp/dpnp_iface_trigonometric.py @@ -57,6 +57,7 @@ dpnp_atanh, dpnp_cos, dpnp_cosh, + dpnp_hypot, dpnp_log, dpnp_sin, dpnp_sinh, @@ -830,7 +831,18 @@ def expm1(x1): return call_origin(numpy.expm1, x1) -def hypot(x1, x2, /, out=None, *, where=True, dtype=None, subok=True, **kwargs): +def hypot( + x1, + x2, + /, + out=None, + *, + where=True, + order="K", + dtype=None, + subok=True, + **kwargs, +): """ Given the "legs" of a right triangle, return its hypotenuse. @@ -848,7 +860,7 @@ def hypot(x1, x2, /, out=None, *, where=True, dtype=None, subok=True, **kwargs): Parameters `where`, `dtype` and `subok` are supported with their default values. Keyword argument `kwargs` is currently unsupported. Otherwise the function will be executed sequentially on CPU. - Input array data types are limited by supported DPNP :ref:`Data types`. + Input array data types are limited by supported real-valued data types. 
Examples -------- @@ -869,60 +881,17 @@ def hypot(x1, x2, /, out=None, *, where=True, dtype=None, subok=True, **kwargs): """ - if kwargs: - pass - elif where is not True: - pass - elif dtype is not None: - pass - elif subok is not True: - pass - elif dpnp.isscalar(x1) and dpnp.isscalar(x2): - # at least either x1 or x2 has to be an array - pass - else: - # get USM type and queue to copy scalar from the host memory into a USM allocation - usm_type, queue = ( - get_usm_allocations([x1, x2]) - if dpnp.isscalar(x1) or dpnp.isscalar(x2) - else (None, None) - ) - - x1_desc = dpnp.get_dpnp_descriptor( - x1, - copy_when_strides=False, - copy_when_nondefault_queue=False, - alloc_usm_type=usm_type, - alloc_queue=queue, - ) - x2_desc = dpnp.get_dpnp_descriptor( - x2, - copy_when_strides=False, - copy_when_nondefault_queue=False, - alloc_usm_type=usm_type, - alloc_queue=queue, - ) - if x1_desc and x2_desc: - if out is not None: - if not dpnp.is_supported_array_type(out): - raise TypeError( - "return array must be of supported array type" - ) - out_desc = ( - dpnp.get_dpnp_descriptor( - out, copy_when_nondefault_queue=False - ) - or None - ) - else: - out_desc = None - - return dpnp_hypot( - x1_desc, x2_desc, dtype=dtype, out=out_desc, where=where - ).get_pyobj() - - return call_origin( - numpy.hypot, x1, x2, dtype=dtype, out=out, where=where, **kwargs + return check_nd_call_func( + numpy.hypot, + dpnp_hypot, + x1, + x2, + out=out, + where=where, + order=order, + dtype=dtype, + subok=subok, + **kwargs, ) diff --git a/tests/skipped_tests_gpu_no_fp64.tbl b/tests/skipped_tests_gpu_no_fp64.tbl index 21aad4100b5..e47d19e0364 100644 --- a/tests/skipped_tests_gpu_no_fp64.tbl +++ b/tests/skipped_tests_gpu_no_fp64.tbl @@ -409,7 +409,6 @@ tests/third_party/cupy/math_tests/test_sumprod.py::TestNansumNanprodLong_param_3 tests/third_party/cupy/math_tests/test_sumprod.py::TestDiff::test_diff_2dim_with_scalar_append 
tests/third_party/cupy/math_tests/test_trigonometric.py::TestTrigonometric::test_deg2rad -tests/third_party/cupy/math_tests/test_trigonometric.py::TestTrigonometric::test_hypot tests/third_party/cupy/math_tests/test_trigonometric.py::TestTrigonometric::test_rad2deg tests/third_party/cupy/random_tests/test_distributions.py::TestDistributionsBeta_param_6_{a_shape=(3, 2), b_shape=(3, 2), shape=(4, 3, 2)}::test_beta diff --git a/tests/test_mathematical.py b/tests/test_mathematical.py index abf2fb18de1..85d015f3775 100644 --- a/tests/test_mathematical.py +++ b/tests/test_mathematical.py @@ -199,7 +199,6 @@ def test_floor_divide(self, dtype, lhs, rhs): "floor_divide", dtype, lhs, rhs, check_type=False ) - @pytest.mark.usefixtures("allow_fall_back_on_numpy") @pytest.mark.parametrize( "dtype", get_all_dtypes(no_bool=True, no_complex=True) ) @@ -966,6 +965,90 @@ def test_invalid_out(self, out): assert_raises(TypeError, numpy.add, a.asnumpy(), 2, out) +class TestHypot: + @pytest.mark.parametrize("dtype", get_float_dtypes()) + def test_hypot(self, dtype): + array1_data = numpy.arange(10) + array2_data = numpy.arange(5, 15) + out = numpy.empty(10, dtype=dtype) + + # DPNP + dp_array1 = dpnp.array(array1_data, dtype=dtype) + dp_array2 = dpnp.array(array2_data, dtype=dtype) + dp_out = dpnp.array(out, dtype=dtype) + result = dpnp.hypot(dp_array1, dp_array2, out=dp_out) + + # original + np_array1 = numpy.array(array1_data, dtype=dtype) + np_array2 = numpy.array(array2_data, dtype=dtype) + expected = numpy.hypot(np_array1, np_array2, out=out) + + assert_allclose(expected, result) + assert_allclose(out, dp_out) + + @pytest.mark.parametrize("dtype", get_float_dtypes()) + def test_out_dtypes(self, dtype): + size = 10 + + np_array1 = numpy.arange(size, 2 * size, dtype=dtype) + np_array2 = numpy.arange(size, dtype=dtype) + np_out = numpy.empty(size, dtype=numpy.float32) + expected = numpy.hypot(np_array1, np_array2, out=np_out) + + dp_array1 = dpnp.arange(size, 2 * size, dtype=dtype) + 
dp_array2 = dpnp.arange(size, dtype=dtype) + + dp_out = dpnp.empty(size, dtype=dpnp.float32) + if dtype != dpnp.float32: + # dtype of out mismatches types of input arrays + with pytest.raises(TypeError): + dpnp.hypot(dp_array1, dp_array2, out=dp_out) + + # allocate new out with expected type + dp_out = dpnp.empty(size, dtype=dtype) + + result = dpnp.hypot(dp_array1, dp_array2, out=dp_out) + + tol = numpy.finfo(numpy.float32).resolution + assert_allclose(expected, result, rtol=tol, atol=tol) + + @pytest.mark.parametrize("dtype", get_float_dtypes()) + def test_out_overlap(self, dtype): + size = 15 + # DPNP + dp_a = dpnp.arange(2 * size, dtype=dtype) + dpnp.hypot(dp_a[size::], dp_a[::2], out=dp_a[:size:]) + + # original + np_a = numpy.arange(2 * size, dtype=dtype) + numpy.hypot(np_a[size::], np_a[::2], out=np_a[:size:]) + + tol = numpy.finfo(numpy.float32).resolution + assert_allclose(np_a, dp_a, rtol=tol, atol=tol) + + @pytest.mark.parametrize( + "shape", [(0,), (15,), (2, 2)], ids=["(0,)", "(15, )", "(2,2)"] + ) + def test_invalid_shape(self, shape): + dp_array1 = dpnp.arange(10) + dp_array2 = dpnp.arange(5, 15) + dp_out = dpnp.empty(shape) + + with pytest.raises(ValueError): + dpnp.hypot(dp_array1, dp_array2, out=dp_out) + + @pytest.mark.parametrize( + "out", + [4, (), [], (3, 7), [2, 4]], + ids=["4", "()", "[]", "(3, 7)", "[2, 4]"], + ) + def test_invalid_out(self, out): + a = dpnp.arange(10) + + assert_raises(TypeError, dpnp.hypot, a, 2, out) + assert_raises(TypeError, numpy.hypot, a.asnumpy(), 2, out) + + class TestFmax: @pytest.mark.parametrize( "dtype", get_all_dtypes(no_bool=True, no_complex=True, no_none=True) diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index 8168b8057cc..56713ca638a 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -364,6 +364,11 @@ def test_proj(device): [-3.0, -2.0, -1.0, 1.0, 2.0, 3.0], [2.0, 2.0, 2.0, 2.0, 2.0, 2.0], ), + pytest.param( + "hypot", + [[1.0, 2.0, 3.0, 4.0]], + [[-1.0, -2.0, -4.0, 
-5.0]], + ), pytest.param( "matmul", [[1.0, 0.0], [0.0, 1.0]], [[4.0, 1.0], [1.0, 2.0]] ), diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index 2ebf0cdb2e3..116fd995d79 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -366,6 +366,11 @@ def test_1in_1out(func, data, usm_type): [[0.0, 1.0, 2.0]], [[3.0, 4.0, 5.0]], ), + pytest.param( + "hypot", + [[1.0, 2.0, 3.0, 4.0]], + [[-1.0, -2.0, -4.0, -5.0]], + ), pytest.param( "maximum", [[0.0, 1.0, 2.0]],