From 6d3be5d859d1c9ceef46a78f7b6cd6ca76948f2c Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Wed, 20 Sep 2023 08:31:15 -0500 Subject: [PATCH 1/4] Use sycl::ext::oneapi::experimental for complex tyes This works around use of double precision functions/literals in implementations of these functions in MSVC headers, causing failures to offload on Iris Xe for single precision input citing lack of fp64 support by the hardware. Changes include CL/sycl.hpp to sycl/sycl.hpp per SYCL-2020 spec For every CMake target, where add_sycl_to_target is used, we also run target_compile_options( ${target_name} PRIVATE -fysl-targets=spir64-unknown-unknown,nvptx64-nvidia-cuda ) Add DPCTL_TARGET_CUDA Boolean cmake option Also DPCTL_SYCL_TARGETS parameter can be used to specify targets to build for. DPCTL_TARGET_CUDA could be set via cmake option, or via environment variable, e.g. ``` $ DPCTL_TARGET_CUDA=1 python scripts/build_locally.py --verbose ``` This calls `target_compile_options` to set sycl-targets for targets needing SYCL --- CMakeLists.txt | 21 + dpctl/CMakeLists.txt | 15 +- dpctl/_host_task_util.hpp | 2 +- dpctl/apis/include/dpctl4pybind11.hpp | 2 +- dpctl/sycl.pxd | 2 +- dpctl/tensor/CMakeLists.txt | 29 +- .../include/kernels/accumulators.hpp | 2 +- .../kernels/boolean_advanced_indexing.hpp | 2 +- .../include/kernels/boolean_reductions.hpp | 2 +- .../include/kernels/constructors.hpp | 2 +- .../include/kernels/copy_and_cast.hpp | 2 +- .../kernels/elementwise_functions/abs.hpp | 6 +- .../kernels/elementwise_functions/acos.hpp | 15 +- .../kernels/elementwise_functions/acosh.hpp | 13 +- .../kernels/elementwise_functions/add.hpp | 30 +- .../kernels/elementwise_functions/asin.hpp | 22 +- .../kernels/elementwise_functions/asinh.hpp | 16 +- .../kernels/elementwise_functions/atan.hpp | 7 +- .../kernels/elementwise_functions/atan2.hpp | 2 +- .../kernels/elementwise_functions/atanh.hpp | 7 +- .../elementwise_functions/bitwise_and.hpp | 2 +- .../elementwise_functions/bitwise_invert.hpp | 2 +- .../bitwise_left_shift.hpp | 2 +- .../elementwise_functions/bitwise_or.hpp | 2 +- .../bitwise_right_shift.hpp | 2 +- .../elementwise_functions/bitwise_xor.hpp | 2 +- .../kernels/elementwise_functions/cbrt.hpp | 2 +- .../kernels/elementwise_functions/ceil.hpp | 2 +- .../kernels/elementwise_functions/common.hpp | 2 +- .../elementwise_functions/common_inplace.hpp | 2 +- .../kernels/elementwise_functions/conj.hpp | 8 +- .../elementwise_functions/copysign.hpp | 2 +- .../kernels/elementwise_functions/cos.hpp | 7 +- .../kernels/elementwise_functions/cosh.hpp | 7 +- .../kernels/elementwise_functions/equal.hpp | 17 +- .../kernels/elementwise_functions/exp.hpp | 7 +- .../kernels/elementwise_functions/exp2.hpp | 6 +- .../kernels/elementwise_functions/expm1.hpp | 2 +- .../kernels/elementwise_functions/floor.hpp | 2 +- .../elementwise_functions/floor_divide.hpp | 2 +- .../kernels/elementwise_functions/greater.hpp | 2 +- .../elementwise_functions/greater_equal.hpp | 2 +- .../kernels/elementwise_functions/hypot.hpp | 2 +- .../kernels/elementwise_functions/imag.hpp | 2 +- .../elementwise_functions/isfinite.hpp | 3 +- .../kernels/elementwise_functions/isinf.hpp | 2 +- .../kernels/elementwise_functions/isnan.hpp | 2 +- .../kernels/elementwise_functions/less.hpp | 2 +- .../elementwise_functions/less_equal.hpp | 2 +- .../kernels/elementwise_functions/log.hpp | 12 +- .../kernels/elementwise_functions/log10.hpp | 8 +- .../kernels/elementwise_functions/log1p.hpp | 2 +- .../kernels/elementwise_functions/log2.hpp | 8 +- .../elementwise_functions/logaddexp.hpp | 2 +- .../elementwise_functions/logical_and.hpp | 2 +- .../elementwise_functions/logical_not.hpp | 2 +- .../elementwise_functions/logical_or.hpp | 2 +- .../elementwise_functions/logical_xor.hpp | 2 +- .../kernels/elementwise_functions/maximum.hpp | 2 +- .../kernels/elementwise_functions/minimum.hpp | 2 +- .../elementwise_functions/multiply.hpp | 17 +- .../elementwise_functions/negative.hpp | 2 +- .../elementwise_functions/not_equal.hpp | 2 +- .../elementwise_functions/positive.hpp | 2 +- .../kernels/elementwise_functions/pow.hpp | 26 +- .../kernels/elementwise_functions/proj.hpp | 2 +- .../kernels/elementwise_functions/real.hpp | 2 +- .../elementwise_functions/remainder.hpp | 2 +- .../kernels/elementwise_functions/round.hpp | 2 +- .../kernels/elementwise_functions/rsqrt.hpp | 2 +- .../kernels/elementwise_functions/sign.hpp | 23 +- .../kernels/elementwise_functions/signbit.hpp | 2 +- .../kernels/elementwise_functions/sin.hpp | 7 +- .../kernels/elementwise_functions/sinh.hpp | 6 +- .../kernels/elementwise_functions/sqrt.hpp | 9 +- .../kernels/elementwise_functions/square.hpp | 15 +- .../elementwise_functions/subtract.hpp | 2 +- .../kernels/elementwise_functions/tan.hpp | 6 +- .../kernels/elementwise_functions/tanh.hpp | 7 +- .../elementwise_functions/true_divide.hpp | 50 +- .../kernels/elementwise_functions/trunc.hpp | 2 +- .../kernels/integer_advanced_indexing.hpp | 2 +- .../libtensor/include/kernels/reductions.hpp | 3 +- .../libtensor/include/kernels/repeat.hpp | 2 +- .../libtensor/include/kernels/where.hpp | 2 +- .../libtensor/include/utils/offset_utils.hpp | 2 +- .../libtensor/include/utils/sycl_utils.hpp | 2 +- .../libtensor/include/utils/type_dispatch.hpp | 2 +- .../libtensor/include/utils/type_utils.hpp | 2 +- .../tensor/libtensor/source/accumulators.cpp | 2 +- .../tensor/libtensor/source/accumulators.hpp | 2 +- .../source/boolean_advanced_indexing.cpp | 2 +- .../source/boolean_advanced_indexing.hpp | 2 +- .../libtensor/source/boolean_reductions.cpp | 2 +- .../libtensor/source/boolean_reductions.hpp | 2 +- .../source/copy_and_cast_usm_to_usm.cpp | 2 +- .../source/copy_and_cast_usm_to_usm.hpp | 2 +- .../libtensor/source/copy_for_reshape.cpp | 2 +- .../libtensor/source/copy_for_reshape.hpp | 2 +- .../tensor/libtensor/source/copy_for_roll.cpp | 2 +- .../tensor/libtensor/source/copy_for_roll.hpp | 2 +- .../copy_numpy_ndarray_into_usm_ndarray.cpp | 2 +- .../copy_numpy_ndarray_into_usm_ndarray.hpp | 2 +- .../source/device_support_queries.cpp | 2 +- .../source/device_support_queries.hpp | 2 +- .../source/elementwise_functions.cpp | 5155 +++++++++++++++++ .../elementwise_functions.hpp | 3 +- dpctl/tensor/libtensor/source/eye_ctor.cpp | 2 +- dpctl/tensor/libtensor/source/eye_ctor.hpp | 2 +- dpctl/tensor/libtensor/source/full_ctor.cpp | 2 +- dpctl/tensor/libtensor/source/full_ctor.hpp | 2 +- .../source/integer_advanced_indexing.cpp | 2 +- .../source/integer_advanced_indexing.hpp | 2 +- .../libtensor/source/linear_sequences.cpp | 2 +- .../libtensor/source/linear_sequences.hpp | 2 +- .../libtensor/source/reduction_over_axis.cpp | 514 ++ .../libtensor/source/reduction_over_axis.hpp | 691 +++ dpctl/tensor/libtensor/source/repeat.cpp | 2 +- dpctl/tensor/libtensor/source/repeat.hpp | 2 +- .../tensor/libtensor/source/tensor_ctors.cpp | 2 +- dpctl/tensor/libtensor/source/triul_ctor.cpp | 2 +- dpctl/tensor/libtensor/source/triul_ctor.hpp | 2 +- dpctl/tensor/libtensor/source/where.cpp | 2 +- dpctl/tensor/libtensor/source/where.hpp | 2 +- dpctl/utils/CMakeLists.txt | 13 + libsyclinterface/CMakeLists.txt | 13 + .../helper/include/dpctl_error_handlers.h | 2 +- .../helper/include/dpctl_utils_helper.h | 2 +- .../include/dpctl_device_selection.hpp | 2 +- .../include/dpctl_sycl_type_casters.hpp | 2 +- .../source/dpctl_device_selection.cpp | 2 +- .../source/dpctl_sycl_context_interface.cpp | 2 +- .../source/dpctl_sycl_device_interface.cpp | 2 +- .../source/dpctl_sycl_device_manager.cpp | 2 +- .../dpctl_sycl_device_selector_interface.cpp | 2 +- .../source/dpctl_sycl_event_interface.cpp | 2 +- .../dpctl_sycl_kernel_bundle_interface.cpp | 14 +- .../source/dpctl_sycl_kernel_interface.cpp | 2 +- .../source/dpctl_sycl_platform_interface.cpp | 2 +- .../source/dpctl_sycl_platform_manager.cpp | 2 +- .../source/dpctl_sycl_queue_interface.cpp | 2 +- .../source/dpctl_sycl_queue_manager.cpp | 2 +- .../source/dpctl_sycl_usm_interface.cpp | 2 +- libsyclinterface/tests/CMakeLists.txt | 13 + libsyclinterface/tests/test_helper.cpp | 2 +- .../tests/test_sycl_context_interface.cpp | 2 +- .../tests/test_sycl_device_aspects.cpp | 2 +- .../tests/test_sycl_device_interface.cpp | 2 +- .../test_sycl_device_invalid_filters.cpp | 2 +- .../test_sycl_device_selector_interface.cpp | 2 +- .../tests/test_sycl_device_subdevices.cpp | 2 +- .../tests/test_sycl_event_interface.cpp | 2 +- .../test_sycl_kernel_bundle_interface.cpp | 2 +- .../tests/test_sycl_kernel_interface.cpp | 2 +- .../tests/test_sycl_platform_interface.cpp | 2 +- .../test_sycl_platform_invalid_filters.cpp | 2 +- .../tests/test_sycl_queue_interface.cpp | 2 +- .../tests/test_sycl_queue_manager.cpp | 2 +- .../tests/test_sycl_queue_submit.cpp | 2 +- .../tests/test_sycl_usm_interface.cpp | 2 +- 160 files changed, 6871 insertions(+), 218 deletions(-) create mode 100644 dpctl/tensor/libtensor/source/elementwise_functions.cpp create mode 100644 dpctl/tensor/libtensor/source/reduction_over_axis.cpp create mode 100644 dpctl/tensor/libtensor/source/reduction_over_axis.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index eb53db12ec..adfb4fbddd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,9 +17,30 @@ option(DPCTL_GENERATE_COVERAGE "Build dpctl with coverage instrumentation" OFF ) +option(DPCTL_TARGET_CUDA + "Build DPCTL to target CUDA devices" + OFF +) find_package(IntelSYCL REQUIRED PATHS ${CMAKE_SOURCE_DIR}/cmake NO_DEFAULT_PATH) +set(_dpctl_sycl_targets) +if ("x${DPCTL_SYCL_TARGETS}" STREQUAL "x") + if(DPCTL_TARGET_CUDA) + set(_dpctl_sycl_targets "nvptx64-nvidia-cuda,spir64-unknown-unknown") + else() + if(DEFINED ENV{DPCTL_TARGET_CUDA}) + set(_dpctl_sycl_targets "nvptx64-nvidia-cuda,spir64-unknown-unknown") + endif() + endif() +else() + set(_dpctl_sycl_targets ${DPCTL_SYCL_TARGETS}) +endif() + +if(_dpctl_sycl_targets) + message(STATUS "Compiling for -fsycl-targets=${_dpctl_sycl_targets}") +endif() + add_subdirectory(libsyclinterface) file(GLOB _dpctl_capi_headers dpctl/apis/include/*.h*) diff --git a/dpctl/CMakeLists.txt b/dpctl/CMakeLists.txt index cb872ff45f..616f270ad3 100644 --- a/dpctl/CMakeLists.txt +++ b/dpctl/CMakeLists.txt @@ -143,7 +143,20 @@ function(build_dpctl_ext _trgt _src _dest) add_custom_target(${_cythonize_trgt} DEPENDS ${_src}) Python_add_library(${_trgt} MODULE WITH_SOABI ${_generated_src}) if (BUILD_DPCTL_EXT_SYCL) - add_sycl_to_target(TARGET ${_trgt} SOURCES ${_generated_src}) + add_sycl_to_target(TARGET ${_trgt} SOURCES ${_generated_src}) + if(_dpctl_sycl_targets) + # make fat binary + target_compile_options( + ${_trgt} + PRIVATE + -fsycl-targets=${_dpctl_sycl_targets} + ) + target_link_options( + ${_trgt} + PRIVATE + -fsycl-targets=${_dpctl_sycl_targets} + ) + endif() endif() target_include_directories(${_trgt} PRIVATE ${NumPy_INCLUDE_DIR} ${DPCTL_INCLUDE_DIR}) add_dependencies(${_trgt} _build_time_create_dpctl_include_copy ${_cythonize_trgt}) diff --git a/dpctl/_host_task_util.hpp b/dpctl/_host_task_util.hpp index cb3828a54f..8349f4c0d9 100644 --- a/dpctl/_host_task_util.hpp +++ b/dpctl/_host_task_util.hpp @@ -33,7 +33,7 @@ #include "Python.h" #include "syclinterface/dpctl_data_types.h" #include "syclinterface/dpctl_sycl_type_casters.hpp" -#include +#include DPCTLSyclEventRef async_dec_ref(DPCTLSyclQueueRef QRef, PyObject **obj_array, diff --git a/dpctl/apis/include/dpctl4pybind11.hpp b/dpctl/apis/include/dpctl4pybind11.hpp index f68826af48..10ee4602c3 100644 --- a/dpctl/apis/include/dpctl4pybind11.hpp +++ b/dpctl/apis/include/dpctl4pybind11.hpp @@ -26,10 +26,10 @@ #pragma once #include "dpctl_capi.h" -#include #include #include #include +#include #include #include diff --git a/dpctl/sycl.pxd b/dpctl/sycl.pxd index 918f476298..0318868ef8 100644 --- a/dpctl/sycl.pxd +++ b/dpctl/sycl.pxd @@ -20,7 +20,7 @@ from . cimport _backend as dpctl_backend -cdef extern from "CL/sycl.hpp" namespace "sycl": +cdef extern from "sycl/sycl.hpp" namespace "sycl": cdef cppclass queue "sycl::queue": pass diff --git a/dpctl/tensor/CMakeLists.txt b/dpctl/tensor/CMakeLists.txt index d1de208805..f2454a9fdc 100644 --- a/dpctl/tensor/CMakeLists.txt +++ b/dpctl/tensor/CMakeLists.txt @@ -188,12 +188,20 @@ foreach(_src_fn ${_no_fast_math_sources}) ) endforeach() if (UNIX) - set_source_files_properties( - ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/abs.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/sqrt.cpp - PROPERTIES COMPILE_DEFINITIONS "USE_STD_ABS_FOR_COMPLEX_TYPES;USE_STD_SQRT_FOR_COMPLEX_TYPES") + set(_compiler_definitions "USE_STD_ABS_FOR_COMPLEX_TYPES;USE_STD_SQRT_FOR_COMPLEX_TYPES;SYCL_EXT_ONEAPI_COMPLEX") +else() + set(_compiler_definitions "SYCL_EXT_ONEAPI_COMPLEX") endif() +foreach(_src_fn ${_elementwise_sources}) + get_source_file_property(_cmpl_options_defs ${_src_fn} COMPILE_DEFINITIONS) + set(_combined_options_defs ${_cmpl_options_defs} "${_compiler_definitions}") + set_source_files_properties( + ${_src_fn} + PROPERTIES COMPILE_DEFINITIONS "${_combined_options_defs}" + ) +endforeach() + set(_linker_options "LINKER:${DPCTL_LDFLAGS}") foreach(python_module_name ${_py_trgts}) target_compile_options(${python_module_name} PRIVATE -fno-sycl-id-queries-fit-in-int) @@ -209,6 +217,19 @@ foreach(python_module_name ${_py_trgts}) ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/ ) target_link_options(${python_module_name} PRIVATE ${_linker_options}) + if(_dpctl_sycl_targets) + # make fat binary + target_compile_options( + ${python_module_name} + PRIVATE + -fsycl-targets=${_dpctl_sycl_targets} + ) + target_link_options( + ${python_module_name} + PRIVATE + -fsycl-targets=${_dpctl_sycl_targets} + ) + endif() add_dependencies(${python_module_name} _dpctl4pybind11_deps) install(TARGETS ${python_module_name} DESTINATION "dpctl/tensor") endforeach() diff --git a/dpctl/tensor/libtensor/include/kernels/accumulators.hpp b/dpctl/tensor/libtensor/include/kernels/accumulators.hpp index a8ef1c423e..bd6ad20b6d 100644 --- a/dpctl/tensor/libtensor/include/kernels/accumulators.hpp +++ b/dpctl/tensor/libtensor/include/kernels/accumulators.hpp @@ -23,11 +23,11 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include #include +#include #include #include diff --git a/dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp b/dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp index 968459fb68..522baadc6d 100644 --- a/dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp +++ b/dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp @@ -23,10 +23,10 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include #include #include diff --git a/dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp b/dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp index 9736b2c2a3..61fb0f6ba0 100644 --- a/dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp +++ b/dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp @@ -24,7 +24,7 @@ //===----------------------------------------------------------------------===// #pragma once -#include +#include #include #include diff --git a/dpctl/tensor/libtensor/include/kernels/constructors.hpp b/dpctl/tensor/libtensor/include/kernels/constructors.hpp index 8870e26ac2..c28033d23d 100644 --- a/dpctl/tensor/libtensor/include/kernels/constructors.hpp +++ b/dpctl/tensor/libtensor/include/kernels/constructors.hpp @@ -27,9 +27,9 @@ #include "utils/offset_utils.hpp" #include "utils/strided_iters.hpp" #include "utils/type_utils.hpp" -#include #include #include +#include namespace dpctl { diff --git a/dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp b/dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp index 0db1f071a1..9d1c788626 100644 --- a/dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp +++ b/dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp @@ -23,10 +23,10 @@ //===----------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include #include #include "utils/offset_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp index bcf6a28040..ab321ad356 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp @@ -23,12 +23,13 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include #include #include +#include +#include #include #include "kernels/elementwise_functions/common.hpp" @@ -49,6 +50,7 @@ namespace abs namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -120,7 +122,7 @@ template struct AbsFunctor } else { #ifdef USE_STD_ABS_FOR_COMPLEX_TYPES - return std::abs(z); + return exprm_ns::abs(exprm_ns::complex(z)); #else return std::hypot(std::real(z), std::imag(z)); #endif diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp index ac1d597c93..23a87b9d44 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp @@ -23,10 +23,11 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include +#include #include #include "kernels/elementwise_functions/common.hpp" @@ -47,6 +48,7 @@ namespace acos namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -103,10 +105,12 @@ template struct AcosFunctor constexpr realT r_eps = realT(1) / std::numeric_limits::epsilon(); if (std::abs(x) > r_eps || std::abs(y) > r_eps) { - argT log_in = std::log(in); + using sycl_complexT = exprm_ns::complex; + sycl_complexT log_in = + exprm_ns::log(exprm_ns::complex(in)); - const realT wx = std::real(log_in); - const realT wy = std::imag(log_in); + const realT wx = log_in.real(); + const realT wy = log_in.imag(); const realT rx = std::abs(wy); realT ry = wx + std::log(realT(2)); @@ -114,7 +118,8 @@ template struct AcosFunctor } /* ordinary cases */ - return std::acos(in); + return exprm_ns::acos( + exprm_ns::complex(in)); // std::acos(in); } else { static_assert(std::is_floating_point_v || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp index 484b0da8a6..56730a411c 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp @@ -23,10 +23,11 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include +#include #include #include "kernels/elementwise_functions/common.hpp" @@ -47,6 +48,7 @@ namespace acosh namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -110,15 +112,18 @@ template struct AcoshFunctor * For large x or y including acos(+-Inf + I*+-Inf) */ if (std::abs(x) > r_eps || std::abs(y) > r_eps) { - const realT wx = std::real(std::log(in)); - const realT wy = std::imag(std::log(in)); + using sycl_complexT = typename exprm_ns::complex; + const sycl_complexT log_in = exprm_ns::log(sycl_complexT(in)); + const realT wx = log_in.real(); + const realT wy = log_in.imag(); const realT rx = std::abs(wy); realT ry = wx + std::log(realT(2)); acos_in = resT{rx, (std::signbit(y)) ? ry : -ry}; } else { /* ordinary cases */ - acos_in = std::acos(in); + acos_in = exprm_ns::acos( + exprm_ns::complex(in)); // std::acos(in); } /* Now we calculate acosh(z) */ diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp index df6797845f..0ed1710833 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp @@ -24,9 +24,10 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include +#include +#include #include #include "utils/offset_utils.hpp" @@ -49,6 +50,7 @@ namespace add namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; +namespace exprm_ns = sycl::ext::oneapi::experimental; template struct AddFunctor { @@ -60,7 +62,31 @@ template struct AddFunctor resT operator()(const argT1 &in1, const argT2 &in2) const { - return in1 + in2; + if constexpr (tu_ns::is_complex::value && + tu_ns::is_complex::value) + { + using rT1 = typename argT1::value_type; + using rT2 = typename argT2::value_type; + + return exprm_ns::complex(in1) + exprm_ns::complex(in2); + } + else if constexpr (tu_ns::is_complex::value && + !tu_ns::is_complex::value) + { + using rT1 = typename argT1::value_type; + + return exprm_ns::complex(in1) + in2; + } + else if constexpr (!tu_ns::is_complex::value && + tu_ns::is_complex::value) + { + using rT2 = typename argT2::value_type; + + return in1 + exprm_ns::complex(in2); + } + else { + return in1 + in2; + } } template diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp index 8b960dd30d..035480c437 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp @@ -23,10 +23,11 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include +#include #include #include "kernels/elementwise_functions/common.hpp" @@ -47,6 +48,7 @@ namespace asin namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -117,24 +119,26 @@ template struct AsinFunctor constexpr realT r_eps = realT(1) / std::numeric_limits::epsilon(); if (std::abs(x) > r_eps || std::abs(y) > r_eps) { - const resT z = {x, y}; + using sycl_complexT = exprm_ns::complex; + const sycl_complexT z{x, y}; realT wx, wy; if (!std::signbit(x)) { - auto log_z = std::log(z); - wx = std::real(log_z) + std::log(realT(2)); - wy = std::imag(log_z); + auto log_z = exprm_ns::log(z); + wx = log_z.real() + std::log(realT(2)); + wy = log_z.imag(); } else { - auto log_mz = std::log(-z); - wx = std::real(log_mz) + std::log(realT(2)); - wy = std::imag(log_mz); + auto log_mz = exprm_ns::log(-z); + wx = log_mz.real() + std::log(realT(2)); + wy = log_mz.imag(); } const realT asinh_re = std::copysign(wx, x); const realT asinh_im = std::copysign(wy, y); return resT{asinh_im, asinh_re}; } /* ordinary cases */ - return std::asin(in); + return exprm_ns::asin( + exprm_ns::complex(in)); // std::asin(in); } else { static_assert(std::is_floating_point_v || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp index 271a861cfe..523ca4f01f 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp @@ -23,10 +23,11 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include +#include #include #include "kernels/elementwise_functions/common.hpp" @@ -47,6 +48,7 @@ namespace asinh namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -106,16 +108,20 @@ template struct AsinhFunctor realT(1) / std::numeric_limits::epsilon(); if (std::abs(x) > r_eps || std::abs(y) > r_eps) { - resT log_in = (std::signbit(x)) ? std::log(-in) : std::log(in); - realT wx = std::real(log_in) + std::log(realT(2)); - realT wy = std::imag(log_in); + using sycl_complexT = exprm_ns::complex; + sycl_complexT log_in = (std::signbit(x)) + ? exprm_ns::log(sycl_complexT(-in)) + : exprm_ns::log(sycl_complexT(in)); + realT wx = log_in.real() + std::log(realT(2)); + realT wy = log_in.imag(); const realT res_re = std::copysign(wx, x); const realT res_im = std::copysign(wy, y); return resT{res_re, res_im}; } /* ordinary cases */ - return std::asinh(in); + return exprm_ns::asinh( + exprm_ns::complex(in)); // std::asinh(in); } else { static_assert(std::is_floating_point_v || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp index f1dcce2831..df8bba538b 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp @@ -23,11 +23,12 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include #include +#include +#include #include #include "kernels/elementwise_functions/common.hpp" @@ -48,6 +49,7 @@ namespace atan namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -126,7 +128,8 @@ template struct AtanFunctor return resT{atanh_im, atanh_re}; } /* ordinary cases */ - return std::atan(in); + return exprm_ns::atan( + exprm_ns::complex(in)); // std::atan(in); } else { static_assert(std::is_floating_point_v || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan2.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan2.hpp index 765c0fe0c3..8df1667312 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan2.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan2.hpp @@ -24,9 +24,9 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include +#include #include #include "utils/offset_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp index 56432d7808..d6a4b06ac3 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp @@ -23,11 +23,12 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include #include +#include +#include #include #include "kernels/elementwise_functions/common.hpp" @@ -48,6 +49,7 @@ namespace atanh namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -119,7 +121,8 @@ template struct AtanhFunctor return resT{res_re, res_im}; } /* ordinary cases */ - return std::atanh(in); + return exprm_ns::atanh( + exprm_ns::complex(in)); // std::atanh(in); } else { static_assert(std::is_floating_point_v || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_and.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_and.hpp index d88d17d3e3..b9a0d41f93 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_and.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_and.hpp @@ -23,9 +23,9 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include +#include #include #include "utils/offset_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_invert.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_invert.hpp index ed4aeeb59e..93e715dc4d 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_invert.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_invert.hpp @@ -25,9 +25,9 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include +#include #include #include "utils/offset_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_left_shift.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_left_shift.hpp index 5cfd6ca5e3..b13971e27a 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_left_shift.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_left_shift.hpp @@ -24,9 +24,9 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include +#include #include #include "utils/offset_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_or.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_or.hpp index d5669d41b1..a07d4ed540 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_or.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_or.hpp @@ -23,9 +23,9 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include +#include #include #include "utils/offset_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_right_shift.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_right_shift.hpp index 5a04165701..ce8537bc31 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_right_shift.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_right_shift.hpp @@ -24,9 +24,9 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include +#include #include #include "utils/offset_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_xor.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_xor.hpp index ec8192fd0f..e2ce5a5703 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_xor.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_xor.hpp @@ -23,9 +23,9 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include +#include #include #include "utils/offset_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cbrt.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cbrt.hpp index 1d4aa65002..92584f0dfe 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cbrt.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cbrt.hpp @@ -24,10 +24,10 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include #include #include "kernels/elementwise_functions/common.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/ceil.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/ceil.hpp index 76fa80c287..0059064ec1 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/ceil.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/ceil.hpp @@ -23,10 +23,10 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include #include #include "kernels/elementwise_functions/common.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp index c0a94be341..5dc4728a65 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp @@ -23,10 +23,10 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include #include #include "utils/offset_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common_inplace.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common_inplace.hpp index 614c7f4092..c4f893a532 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common_inplace.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common_inplace.hpp @@ -24,10 +24,10 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include namespace dpctl { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp index 3b0a1584de..6977e3a747 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp @@ -24,11 +24,12 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include #include +#include +#include #include #include "kernels/elementwise_functions/common.hpp" @@ -49,6 +50,7 @@ namespace conj namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -68,7 +70,9 @@ template struct ConjFunctor resT operator()(const argT &in) const { if constexpr (is_complex::value) { - return std::conj(in); + using rT = typename argT::value_type; + + return exprm_ns::conj(exprm_ns::complex(in)); // std::conj(in); } else { if constexpr (!std::is_same_v) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/copysign.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/copysign.hpp index b1997d06b4..43e06cb281 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/copysign.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/copysign.hpp @@ -24,9 +24,9 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include +#include #include #include "utils/offset_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp index f7c66d5f68..bdc1acc1fe 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp @@ -23,10 +23,11 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include +#include #include #include "kernels/elementwise_functions/common.hpp" @@ -47,6 +48,7 @@ namespace cos namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -81,7 +83,8 @@ template struct CosFunctor * real and imaginary parts of input are finite. */ if (in_re_finite && in_im_finite) { - return std::cos(in); + return exprm_ns::cos( + exprm_ns::complex(in)); // std::cos(in); } /* diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp index fbcc7e40f9..7093d2a2a3 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp @@ -23,10 +23,11 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include +#include #include #include "kernels/elementwise_functions/common.hpp" @@ -47,6 +48,7 @@ namespace cosh namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -81,7 +83,8 @@ template struct CoshFunctor * real and imaginary parts of input are finite. */ if (xfinite && yfinite) { - return std::cosh(in); + return exprm_ns::cosh( + exprm_ns::complex(in)); // std::cosh(in); } /* diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp index cd726f72ea..6d68861396 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp @@ -24,9 +24,10 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include +#include +#include #include #include "utils/offset_utils.hpp" @@ -48,6 +49,7 @@ namespace equal namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; +namespace exprm_ns = sycl::ext::oneapi::experimental; template struct EqualFunctor { @@ -62,7 +64,18 @@ template struct EqualFunctor resT operator()(const argT1 &in1, const argT2 &in2) const { - return (in1 == in2); + if constexpr (tu_ns::is_complex::value && + tu_ns::is_complex::value) + { + using realT1 = typename argT1::value_type; + using realT2 = typename argT2::value_type; + + return exprm_ns::complex(in1) == + exprm_ns::complex(in2); + } + else { + return (in1 == in2); + } } template diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp index 003de44c27..453eb05c52 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp @@ -23,10 +23,11 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include +#include #include #include "kernels/elementwise_functions/common.hpp" @@ -47,6 +48,7 @@ namespace exp namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -73,7 +75,8 @@ template struct ExpFunctor const realT y = std::imag(in); if (std::isfinite(x)) { if (std::isfinite(y)) { - return std::exp(in); + return exprm_ns::exp( + exprm_ns::complex(in)); // std::exp(in); } else { return resT{q_nan, q_nan}; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp index 67ee23df48..b6b2f32e83 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp @@ -24,10 +24,11 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include +#include #include #include "kernels/elementwise_functions/common.hpp" @@ -48,6 +49,7 @@ namespace exp2 namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -76,7 +78,7 @@ template struct Exp2Functor const realT y = std::imag(tmp); if (std::isfinite(x)) { if (std::isfinite(y)) { - return std::exp(tmp); + return exprm_ns::exp(exprm_ns::complex(tmp)); } else { return resT{q_nan, q_nan}; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp index 3f6a73b6d3..f5204e87b3 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp @@ -24,11 +24,11 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include #include +#include #include #include "kernels/elementwise_functions/common.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor.hpp index e675407d0b..88a20dafe0 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor.hpp @@ -23,10 +23,10 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include #include #include "kernels/elementwise_functions/common.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp index 241c0e7ca8..4ba335c98b 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp @@ -24,9 +24,9 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include +#include #include #include "utils/offset_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp index 2a151ce737..e01360efa7 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp @@ -25,9 +25,9 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include +#include #include #include "utils/math_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater_equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater_equal.hpp index 5704336990..f017b7f150 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater_equal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater_equal.hpp @@ -25,9 +25,9 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include +#include #include #include "utils/math_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/hypot.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/hypot.hpp index a369c54f24..fd19d29c0b 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/hypot.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/hypot.hpp @@ -24,9 +24,9 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include +#include #include #include "utils/offset_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp index 64da603037..bb1ff2ebcb 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp @@ -24,11 +24,11 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include #include +#include #include #include "kernels/elementwise_functions/common.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp index 1d8f177e40..1554f905b7 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp @@ -24,10 +24,10 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include #include #include "utils/offset_utils.hpp" @@ -46,6 +46,7 @@ namespace isfinite namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; using dpctl::tensor::type_utils::vec_cast; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp index d9afdb9317..2720385614 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp @@ -24,10 +24,10 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include #include #include "utils/offset_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp index b5051ab833..15551e295a 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp @@ -24,9 +24,9 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include +#include #include #include "utils/offset_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp index c33d6d7c10..02c7a0d95a 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp @@ -24,9 +24,9 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include +#include #include #include "utils/math_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less_equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less_equal.hpp index 47e2301fe7..f9f6729968 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less_equal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less_equal.hpp @@ -25,9 +25,9 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include +#include #include #include "utils/math_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp index abcc899fc0..ff37d87157 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp @@ -23,10 +23,11 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include +#include #include #include "kernels/elementwise_functions/common.hpp" @@ -47,6 +48,7 @@ namespace log namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -65,7 +67,13 @@ template struct LogFunctor resT operator()(const argT &in) const { - return std::log(in); + if constexpr (is_complex::value) { + using realT = typename argT::value_type; + return exprm_ns::log(exprm_ns::complex(in)); // std::log(in); + } + else { + return std::log(in); + } } }; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp index afcf8aa085..88dabcaabe 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp @@ -24,10 +24,11 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include +#include #include #include "kernels/elementwise_functions/common.hpp" @@ -48,6 +49,7 @@ namespace log10 namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; using dpctl::tensor::type_utils::vec_cast; @@ -70,7 +72,9 @@ template struct Log10Functor { if constexpr (is_complex::value) { using realT = typename argT::value_type; - return (std::log(in) / std::log(realT{10})); + // return (std::log(in) / std::log(realT{10})); + return exprm_ns::log(exprm_ns::complex(in)) / + std::log(realT{10}); } else { return std::log10(in); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp index 6d7a56ccf5..11e3fb3f9f 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp @@ -24,10 +24,10 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include #include #include "kernels/elementwise_functions/common.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp index 533d0120df..57d7dcaf31 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp @@ -24,10 +24,11 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include +#include #include #include "kernels/elementwise_functions/common.hpp" @@ -48,6 +49,7 @@ namespace log2 namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; using dpctl::tensor::type_utils::vec_cast; @@ -70,7 +72,9 @@ template struct Log2Functor { if constexpr (is_complex::value) { using realT = typename argT::value_type; - return std::log(in) / std::log(realT{2}); + // std::log(in) / std::log(realT{2}); + return exprm_ns::log(exprm_ns::complex(in)) / + std::log(realT{2}); } else { return std::log2(in); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp index 6a187da6f4..e918454319 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp @@ -25,10 +25,10 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include #include #include "utils/math_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_and.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_and.hpp index 10e4e0cbff..988d1ed380 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_and.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_and.hpp @@ -25,9 +25,9 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include +#include #include #include "utils/offset_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_not.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_not.hpp index 78bacbe686..826af2ee37 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_not.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_not.hpp @@ -25,9 +25,9 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include +#include #include #include "utils/offset_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_or.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_or.hpp index bfb1288870..333951e6b5 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_or.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_or.hpp @@ -25,9 +25,9 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include +#include #include #include "utils/offset_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_xor.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_xor.hpp index 44d361cfc1..ce4bde9e6b 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_xor.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_xor.hpp @@ -25,9 +25,9 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include +#include #include #include "utils/offset_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp index 324f3f5ad2..8a1990ba7d 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp @@ -24,9 +24,9 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include +#include #include #include "utils/math_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp index 9a7ec72e56..fb3490ee19 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp @@ -24,9 +24,9 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include +#include #include #include "utils/math_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp index c316279a76..612ad78360 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp @@ -24,9 +24,10 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include +#include +#include #include #include "utils/offset_utils.hpp" @@ -49,6 +50,7 @@ namespace multiply namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; +namespace exprm_ns = sycl::ext::oneapi::experimental; template struct MultiplyFunctor { @@ -60,7 +62,18 @@ template struct MultiplyFunctor resT operator()(const argT1 &in1, const argT2 &in2) const { - return in1 * in2; + if constexpr (tu_ns::is_complex::value && + tu_ns::is_complex::value) + { + using realT1 = typename argT1::value_type; + using realT2 = typename argT2::value_type; + + return exprm_ns::complex(in1) * + exprm_ns::complex(in2); + } + else { + return in1 * in2; + } } template diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/negative.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/negative.hpp index cbeeb60b7c..bc28aafad7 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/negative.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/negative.hpp @@ -24,10 +24,10 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include #include #include "kernels/elementwise_functions/common.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/not_equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/not_equal.hpp index 88e077b402..faeab82580 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/not_equal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/not_equal.hpp @@ -24,9 +24,9 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include +#include #include #include "utils/offset_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/positive.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/positive.hpp index cbeba2e91d..b3e109c76c 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/positive.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/positive.hpp @@ -24,10 +24,10 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include #include #include "kernels/elementwise_functions/common.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp index 6654bae384..95e8442903 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp @@ -24,10 +24,11 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include +#include #include #include "utils/offset_utils.hpp" @@ -50,6 +51,7 @@ namespace pow namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; +namespace exprm_ns = sycl::ext::oneapi::experimental; template struct PowFunctor { @@ -83,6 +85,15 @@ template struct PowFunctor } return res; } + else if constexpr (tu_ns::is_complex::value && + tu_ns::is_complex::value) + { + using realT1 = typename argT1::value_type; + using realT2 = typename argT2::value_type; + + return exprm_ns::pow(exprm_ns::complex(in1), + exprm_ns::complex(in2)); + } else { return std::pow(in1, in2); } @@ -350,11 +361,20 @@ template struct PowInplaceFunctor tmp1 *= tmp1; } res = res_tmp; - return; + } + else if constexpr (tu_ns::is_complex::value && + tu_ns::is_complex::value) + { + using r_resT = typename resT::value_type; + using r_argT = typename argT::value_type; + + res = exprm_ns::pow(exprm_ns::complex(res), + exprm_ns::complex(in)); } else { res = std::pow(res, in); - }; + } + return; } template diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp index dcaa4b0f5f..92f5ffa729 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp @@ -24,12 +24,12 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include #include #include +#include #include #include "kernels/elementwise_functions/common.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp index 294b796e96..6a7580d548 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp @@ -24,11 +24,11 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include #include +#include #include #include "kernels/elementwise_functions/common.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/remainder.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/remainder.hpp index 051a1f9029..9e64e6500f 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/remainder.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/remainder.hpp @@ -25,9 +25,9 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include +#include #include #include "utils/offset_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp index 84d8fb7252..547d31b392 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp @@ -23,10 +23,10 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include #include #include "kernels/elementwise_functions/common.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/rsqrt.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/rsqrt.hpp index de51b31c30..d9e0c33081 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/rsqrt.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/rsqrt.hpp @@ -25,12 +25,12 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include #include #include +#include #include #include "kernels/elementwise_functions/common.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp index fc3d44dcfa..162db394de 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp @@ -23,11 +23,12 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include #include +#include +#include #include #include "kernels/elementwise_functions/common.hpp" @@ -48,6 +49,7 @@ namespace sign namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; using dpctl::tensor::type_utils::vec_cast; @@ -61,38 +63,41 @@ template struct SignFunctor std::disjunction, is_complex>>; using supports_sg_loadstore = std::false_type; - resT operator()(const argT &x) const + resT operator()(const argT &in) const { if constexpr (std::is_integral_v) { if constexpr (std::is_unsigned_v) { - return resT(0 < x); + return resT(0 < in); } else { - return sign(x); + return sign_impl(in); } } else { if constexpr (is_complex::value) { - if (x == argT(0)) { + using realT = typename argT::value_type; + + if (in == argT(0)) { return resT(0); } else { - return (x / std::abs(x)); + auto z = exprm_ns::complex(in); + return (z / exprm_ns::abs(z)); } } else { - if (std::isnan(x)) { + if (std::isnan(in)) { return std::numeric_limits::quiet_NaN(); } else { - return sign(x); + return sign_impl(in); } } } } private: - template T sign(const T &v) const + template T sign_impl(const T &v) const { return (T(0) < v) - (v < T(0)); } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/signbit.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/signbit.hpp index 0f509f7950..3e961c466d 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/signbit.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/signbit.hpp @@ -24,10 +24,10 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include #include #include "utils/offset_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp index b9f03e6234..e1e9e79c57 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp @@ -23,10 +23,11 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include +#include #include #include "kernels/elementwise_functions/common.hpp" @@ -47,6 +48,7 @@ namespace sin namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -79,7 +81,8 @@ template struct SinFunctor * real and imaginary parts of input are finite. */ if (in_re_finite && in_im_finite) { - return std::sin(in); + return exprm_ns::sin( + exprm_ns::complex(in)); // std::sin(in); } /* diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp index 3a8d05d774..b11c7402d0 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp @@ -23,10 +23,11 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include +#include #include #include "kernels/elementwise_functions/common.hpp" @@ -47,6 +48,7 @@ namespace sinh namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -79,7 +81,7 @@ template struct SinhFunctor * real and imaginary parts of input are finite. */ if (xfinite && yfinite) { - return std::sinh(in); + return exprm_ns::sinh(exprm_ns::complex(in)); } /* * sinh(+-0 +- I Inf) = sign(d(+-0, dNaN))0 + I dNaN. diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp index efa580d70e..b638e4a55f 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp @@ -24,12 +24,13 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include #include #include +#include +#include #include #include "kernels/elementwise_functions/common.hpp" @@ -50,6 +51,7 @@ namespace sqrt namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -74,7 +76,10 @@ template struct SqrtFunctor // #else // return std::sqrt(in); // #endif - return csqrt(in); + using realT = typename argT::value_type; + + // return csqrt(in); + return exprm_ns::sqrt(exprm_ns::complex(in)); } else { return std::sqrt(in); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp index 6b5f372c3d..2c37ce87d9 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp @@ -24,10 +24,11 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include +#include #include #include "kernels/elementwise_functions/common.hpp" @@ -48,6 +49,7 @@ namespace square namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; using dpctl::tensor::type_utils::vec_cast; @@ -68,7 +70,16 @@ template struct SquareFunctor resT operator()(const argT &in) const { - return in * in; + if constexpr (is_complex::value) { + using realT = typename argT::value_type; + + auto z = exprm_ns::complex(in); + + return z * z; + } + else { + return in * in; + } } template diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp index e4ae857738..9447873dec 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp @@ -24,9 +24,9 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include +#include #include #include "utils/offset_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp index 45f931b7f4..1f97b59054 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp @@ -23,11 +23,12 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include #include +#include +#include #include #include "kernels/elementwise_functions/common.hpp" @@ -48,6 +49,7 @@ namespace tan namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; +namespace cmplx_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -118,7 +120,7 @@ template struct TanFunctor return resT{q_nan, q_nan}; } /* ordinary cases */ - return std::tan(in); + return cmplx_ns::tan(cmplx_ns::complex(in)); // std::tan(in); } else { static_assert(std::is_floating_point_v || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp index ef943319b2..453ce17b54 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp @@ -24,11 +24,12 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include #include +#include +#include #include #include "kernels/elementwise_functions/common.hpp" @@ -49,6 +50,7 @@ namespace tanh namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; +namespace cmplx_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -112,7 +114,8 @@ template struct TanhFunctor return resT{q_nan, q_nan}; } /* ordinary cases */ - return std::tanh(in); + return cmplx_ns::tanh( + cmplx_ns::complex(in)); // std::tanh(in); } else { static_assert(std::is_floating_point_v || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp index 86fb0ca2e2..6620d2e3c1 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp @@ -24,9 +24,10 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include +#include +#include #include #include "utils/offset_utils.hpp" @@ -49,6 +50,7 @@ namespace true_divide namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; +namespace exprm_ns = sycl::ext::oneapi::experimental; template struct TrueDivideFunctor @@ -61,7 +63,32 @@ struct TrueDivideFunctor resT operator()(const argT1 &in1, const argT2 &in2) const { - return in1 / in2; + if constexpr (tu_ns::is_complex::value && + tu_ns::is_complex::value) + { + using realT1 = typename argT1::value_type; + using realT2 = typename argT2::value_type; + + return exprm_ns::complex(in1) / + exprm_ns::complex(in2); + } + else if constexpr (tu_ns::is_complex::value && + !tu_ns::is_complex::value) + { + using realT1 = typename argT1::value_type; + + return exprm_ns::complex(in1) / in2; + } + else if constexpr (!tu_ns::is_complex::value && + tu_ns::is_complex::value) + { + using realT2 = typename argT2::value_type; + + return in1 / exprm_ns::complex(in2); + } + else { + return in1 / in2; + } } template @@ -381,7 +408,24 @@ template struct TrueDivideInplaceFunctor void operator()(resT &res, const argT &in) { - res /= in; + if constexpr (tu_ns::is_complex::value) { + using res_rT = typename resT::value_type; + if constexpr (tu_ns::is_complex::value) { + using arg_rT = typename argT::value_type; + + auto res1 = exprm_ns::complex(res); + res1 /= exprm_ns::complex(in); + res = res1; + } + else { + auto res1 = exprm_ns::complex(res); + res1 /= in; + res = res1; + } + } + else { + res /= in; + } } template diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/trunc.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/trunc.hpp index 33e942dd6a..0e08d966e9 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/trunc.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/trunc.hpp @@ -23,10 +23,10 @@ //===---------------------------------------------------------------------===// #pragma once -#include #include #include #include +#include #include #include "kernels/elementwise_functions/common.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/integer_advanced_indexing.hpp b/dpctl/tensor/libtensor/include/kernels/integer_advanced_indexing.hpp index 6acf0a9f50..769774f4dd 100644 --- a/dpctl/tensor/libtensor/include/kernels/integer_advanced_indexing.hpp +++ b/dpctl/tensor/libtensor/include/kernels/integer_advanced_indexing.hpp @@ -23,11 +23,11 @@ //===----------------------------------------------------------------------===// #pragma once -#include #include #include #include #include +#include #include #include "utils/offset_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/reductions.hpp b/dpctl/tensor/libtensor/include/kernels/reductions.hpp index adbf96be10..a8d9cf1972 100644 --- a/dpctl/tensor/libtensor/include/kernels/reductions.hpp +++ b/dpctl/tensor/libtensor/include/kernels/reductions.hpp @@ -23,11 +23,10 @@ //===----------------------------------------------------------------------===// #pragma once -#include -#include #include #include #include +#include #include #include #include diff --git a/dpctl/tensor/libtensor/include/kernels/repeat.hpp b/dpctl/tensor/libtensor/include/kernels/repeat.hpp index 1f2335fc6c..05b57a8cda 100644 --- a/dpctl/tensor/libtensor/include/kernels/repeat.hpp +++ b/dpctl/tensor/libtensor/include/kernels/repeat.hpp @@ -23,11 +23,11 @@ //===----------------------------------------------------------------------===// #pragma once -#include #include #include #include #include +#include #include #include "utils/offset_utils.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/where.hpp b/dpctl/tensor/libtensor/include/kernels/where.hpp index fc9546a9a8..9558603d5e 100644 --- a/dpctl/tensor/libtensor/include/kernels/where.hpp +++ b/dpctl/tensor/libtensor/include/kernels/where.hpp @@ -27,11 +27,11 @@ #include "pybind11/stl.h" #include "utils/offset_utils.hpp" #include "utils/type_utils.hpp" -#include #include #include #include #include +#include #include namespace dpctl diff --git a/dpctl/tensor/libtensor/include/utils/offset_utils.hpp b/dpctl/tensor/libtensor/include/utils/offset_utils.hpp index 29517ce2c5..523620737b 100644 --- a/dpctl/tensor/libtensor/include/utils/offset_utils.hpp +++ b/dpctl/tensor/libtensor/include/utils/offset_utils.hpp @@ -26,9 +26,9 @@ #pragma once -#include #include #include +#include #include #include diff --git a/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp b/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp index c0165b0ecc..b9b0ee08c0 100644 --- a/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp +++ b/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp @@ -23,9 +23,9 @@ //===----------------------------------------------------------------------===// #pragma once -#include #include #include +#include #include #include diff --git a/dpctl/tensor/libtensor/include/utils/type_dispatch.hpp b/dpctl/tensor/libtensor/include/utils/type_dispatch.hpp index afc458169e..af031a963b 100644 --- a/dpctl/tensor/libtensor/include/utils/type_dispatch.hpp +++ b/dpctl/tensor/libtensor/include/utils/type_dispatch.hpp @@ -25,8 +25,8 @@ #pragma once #include "dpctl4pybind11.hpp" -#include #include +#include namespace dpctl { diff --git a/dpctl/tensor/libtensor/include/utils/type_utils.hpp b/dpctl/tensor/libtensor/include/utils/type_utils.hpp index 4ea17ac730..a50e5159e4 100644 --- a/dpctl/tensor/libtensor/include/utils/type_utils.hpp +++ b/dpctl/tensor/libtensor/include/utils/type_utils.hpp @@ -23,9 +23,9 @@ //===----------------------------------------------------------------------===// #pragma once -#include #include #include +#include #include namespace dpctl diff --git a/dpctl/tensor/libtensor/source/accumulators.cpp b/dpctl/tensor/libtensor/source/accumulators.cpp index 40f4424ef9..0a2ce69f69 100644 --- a/dpctl/tensor/libtensor/source/accumulators.cpp +++ b/dpctl/tensor/libtensor/source/accumulators.cpp @@ -23,11 +23,11 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include #include +#include #include #include diff --git a/dpctl/tensor/libtensor/source/accumulators.hpp b/dpctl/tensor/libtensor/source/accumulators.hpp index 4979eab54f..ba40c38e1d 100644 --- a/dpctl/tensor/libtensor/source/accumulators.hpp +++ b/dpctl/tensor/libtensor/source/accumulators.hpp @@ -23,7 +23,7 @@ //===--------------------------------------------------------------------===// #pragma once -#include +#include #include #include diff --git a/dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp b/dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp index ff7b32d0f7..903e1b5536 100644 --- a/dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp +++ b/dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp @@ -24,11 +24,11 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include #include +#include #include #include diff --git a/dpctl/tensor/libtensor/source/boolean_advanced_indexing.hpp b/dpctl/tensor/libtensor/source/boolean_advanced_indexing.hpp index 26f1c6a646..8347d9f687 100644 --- a/dpctl/tensor/libtensor/source/boolean_advanced_indexing.hpp +++ b/dpctl/tensor/libtensor/source/boolean_advanced_indexing.hpp @@ -24,7 +24,7 @@ //===----------------------------------------------------------------------===// #pragma once -#include +#include #include #include diff --git a/dpctl/tensor/libtensor/source/boolean_reductions.cpp b/dpctl/tensor/libtensor/source/boolean_reductions.cpp index 5f3c1f5e51..32deab6da9 100644 --- a/dpctl/tensor/libtensor/source/boolean_reductions.cpp +++ b/dpctl/tensor/libtensor/source/boolean_reductions.cpp @@ -24,8 +24,8 @@ /// dpctl.tensor.all and dpctl.tensor.any //===----------------------------------------------------------------------===// -#include #include +#include #include #include diff --git a/dpctl/tensor/libtensor/source/boolean_reductions.hpp b/dpctl/tensor/libtensor/source/boolean_reductions.hpp index 5a0d5d381a..4d59463f8b 100644 --- a/dpctl/tensor/libtensor/source/boolean_reductions.hpp +++ b/dpctl/tensor/libtensor/source/boolean_reductions.hpp @@ -25,11 +25,11 @@ #pragma once #include "dpctl4pybind11.hpp" -#include #include #include #include #include +#include #include #include diff --git a/dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.cpp b/dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.cpp index 290ab88fe8..51ddd81312 100644 --- a/dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.cpp +++ b/dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.cpp @@ -22,7 +22,6 @@ /// This file defines functions of dpctl.tensor._tensor_impl extensions //===----------------------------------------------------------------------===// -#include #include #include #include @@ -30,6 +29,7 @@ #include #include #include +#include #include #include #include diff --git a/dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.hpp b/dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.hpp index c2161f1ba6..c8196b416a 100644 --- a/dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.hpp +++ b/dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.hpp @@ -23,7 +23,7 @@ //===----------------------------------------------------------------------===// #pragma once -#include +#include #include #include diff --git a/dpctl/tensor/libtensor/source/copy_for_reshape.cpp b/dpctl/tensor/libtensor/source/copy_for_reshape.cpp index c9ab58528a..235878b820 100644 --- a/dpctl/tensor/libtensor/source/copy_for_reshape.cpp +++ b/dpctl/tensor/libtensor/source/copy_for_reshape.cpp @@ -22,7 +22,7 @@ /// This file defines functions of dpctl.tensor._tensor_impl extensions //===----------------------------------------------------------------------===// -#include +#include #include #include diff --git a/dpctl/tensor/libtensor/source/copy_for_reshape.hpp b/dpctl/tensor/libtensor/source/copy_for_reshape.hpp index 2f25a68480..cd4ca68ff0 100644 --- a/dpctl/tensor/libtensor/source/copy_for_reshape.hpp +++ b/dpctl/tensor/libtensor/source/copy_for_reshape.hpp @@ -23,7 +23,7 @@ //===----------------------------------------------------------------------===// #pragma once -#include +#include #include #include diff --git a/dpctl/tensor/libtensor/source/copy_for_roll.cpp b/dpctl/tensor/libtensor/source/copy_for_roll.cpp index cc319e6e08..ab36f543af 100644 --- a/dpctl/tensor/libtensor/source/copy_for_roll.cpp +++ b/dpctl/tensor/libtensor/source/copy_for_roll.cpp @@ -22,7 +22,7 @@ /// This file defines functions of dpctl.tensor._tensor_impl extensions //===----------------------------------------------------------------------===// -#include +#include #include #include diff --git a/dpctl/tensor/libtensor/source/copy_for_roll.hpp b/dpctl/tensor/libtensor/source/copy_for_roll.hpp index 38e84b9c6a..357d821eff 100644 --- a/dpctl/tensor/libtensor/source/copy_for_roll.hpp +++ b/dpctl/tensor/libtensor/source/copy_for_roll.hpp @@ -23,7 +23,7 @@ //===----------------------------------------------------------------------===// #pragma once -#include +#include #include #include diff --git a/dpctl/tensor/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp b/dpctl/tensor/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp index bb367a42b9..f644522c18 100644 --- a/dpctl/tensor/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp +++ b/dpctl/tensor/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp @@ -22,8 +22,8 @@ /// This file defines functions of dpctl.tensor._tensor_impl extensions //===----------------------------------------------------------------------===// -#include #include +#include #include #include "dpctl4pybind11.hpp" diff --git a/dpctl/tensor/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.hpp b/dpctl/tensor/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.hpp index 3f1833ec99..247a5d7314 100644 --- a/dpctl/tensor/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.hpp +++ b/dpctl/tensor/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.hpp @@ -23,7 +23,7 @@ //===----------------------------------------------------------------------===// #pragma once -#include +#include #include #include "dpctl4pybind11.hpp" diff --git a/dpctl/tensor/libtensor/source/device_support_queries.cpp b/dpctl/tensor/libtensor/source/device_support_queries.cpp index 9f793cb00a..cb0dbc02a5 100644 --- a/dpctl/tensor/libtensor/source/device_support_queries.cpp +++ b/dpctl/tensor/libtensor/source/device_support_queries.cpp @@ -25,9 +25,9 @@ #include #include "dpctl4pybind11.hpp" -#include #include #include +#include namespace dpctl { diff --git a/dpctl/tensor/libtensor/source/device_support_queries.hpp b/dpctl/tensor/libtensor/source/device_support_queries.hpp index 3367f8bfc2..efffd4ac93 100644 --- a/dpctl/tensor/libtensor/source/device_support_queries.hpp +++ b/dpctl/tensor/libtensor/source/device_support_queries.hpp @@ -26,9 +26,9 @@ #include #include "dpctl4pybind11.hpp" -#include #include #include +#include namespace dpctl { diff --git a/dpctl/tensor/libtensor/source/elementwise_functions.cpp b/dpctl/tensor/libtensor/source/elementwise_functions.cpp new file mode 100644 index 0000000000..9ab7c0807c --- /dev/null +++ b/dpctl/tensor/libtensor/source/elementwise_functions.cpp @@ -0,0 +1,5155 @@ +//===----------- Implementation of _tensor_impl module ---------*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions, +/// specifically functions for elementwise operations. +//===----------------------------------------------------------------------===// + +#include "dpctl4pybind11.hpp" +#include +#include +#include +#include +#include + +#include "elementwise_functions.hpp" +#include "utils/type_dispatch.hpp" + +#include "kernels/elementwise_functions/abs.hpp" +#include "kernels/elementwise_functions/acos.hpp" +#include "kernels/elementwise_functions/acosh.hpp" +#include "kernels/elementwise_functions/add.hpp" +#include "kernels/elementwise_functions/asin.hpp" +#include "kernels/elementwise_functions/asinh.hpp" +#include "kernels/elementwise_functions/atan.hpp" +#include "kernels/elementwise_functions/atan2.hpp" +#include "kernels/elementwise_functions/atanh.hpp" +#include "kernels/elementwise_functions/bitwise_and.hpp" +#include "kernels/elementwise_functions/bitwise_invert.hpp" +#include "kernels/elementwise_functions/bitwise_left_shift.hpp" +#include "kernels/elementwise_functions/bitwise_or.hpp" +#include "kernels/elementwise_functions/bitwise_right_shift.hpp" +#include "kernels/elementwise_functions/bitwise_xor.hpp" +#include "kernels/elementwise_functions/cbrt.hpp" +#include "kernels/elementwise_functions/ceil.hpp" +#include "kernels/elementwise_functions/conj.hpp" +#include "kernels/elementwise_functions/copysign.hpp" +#include "kernels/elementwise_functions/cos.hpp" +#include "kernels/elementwise_functions/cosh.hpp" +#include "kernels/elementwise_functions/equal.hpp" +#include "kernels/elementwise_functions/exp.hpp" +#include "kernels/elementwise_functions/exp2.hpp" +#include "kernels/elementwise_functions/expm1.hpp" +#include "kernels/elementwise_functions/floor.hpp" +#include "kernels/elementwise_functions/floor_divide.hpp" +#include "kernels/elementwise_functions/greater.hpp" +#include "kernels/elementwise_functions/greater_equal.hpp" +#include "kernels/elementwise_functions/hypot.hpp" +#include "kernels/elementwise_functions/imag.hpp" +#include "kernels/elementwise_functions/isfinite.hpp" +#include "kernels/elementwise_functions/isinf.hpp" +#include "kernels/elementwise_functions/isnan.hpp" +#include "kernels/elementwise_functions/less.hpp" +#include "kernels/elementwise_functions/less_equal.hpp" +#include "kernels/elementwise_functions/log.hpp" +#include "kernels/elementwise_functions/log10.hpp" +#include "kernels/elementwise_functions/log1p.hpp" +#include "kernels/elementwise_functions/log2.hpp" +#include "kernels/elementwise_functions/logaddexp.hpp" +#include "kernels/elementwise_functions/logical_and.hpp" +#include "kernels/elementwise_functions/logical_not.hpp" +#include "kernels/elementwise_functions/logical_or.hpp" +#include "kernels/elementwise_functions/logical_xor.hpp" +#include "kernels/elementwise_functions/maximum.hpp" +#include "kernels/elementwise_functions/minimum.hpp" +#include "kernels/elementwise_functions/multiply.hpp" +#include "kernels/elementwise_functions/negative.hpp" +#include "kernels/elementwise_functions/not_equal.hpp" +#include "kernels/elementwise_functions/positive.hpp" +#include "kernels/elementwise_functions/pow.hpp" +#include "kernels/elementwise_functions/proj.hpp" +#include "kernels/elementwise_functions/real.hpp" +#include "kernels/elementwise_functions/remainder.hpp" +#include "kernels/elementwise_functions/round.hpp" +#include "kernels/elementwise_functions/rsqrt.hpp" +#include "kernels/elementwise_functions/sign.hpp" +#include "kernels/elementwise_functions/signbit.hpp" +#include "kernels/elementwise_functions/sin.hpp" +#include "kernels/elementwise_functions/sinh.hpp" +#include "kernels/elementwise_functions/sqrt.hpp" +#include "kernels/elementwise_functions/square.hpp" +#include "kernels/elementwise_functions/subtract.hpp" +#include "kernels/elementwise_functions/tan.hpp" +#include "kernels/elementwise_functions/tanh.hpp" +#include "kernels/elementwise_functions/true_divide.hpp" +#include "kernels/elementwise_functions/trunc.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +namespace td_ns = dpctl::tensor::type_dispatch; + +py::dtype _dtype_from_typenum(td_ns::typenum_t dst_typenum_t) +{ + switch (dst_typenum_t) { + case td_ns::typenum_t::BOOL: + return py::dtype("?"); + case td_ns::typenum_t::INT8: + return py::dtype("i1"); + case td_ns::typenum_t::UINT8: + return py::dtype("u1"); + case td_ns::typenum_t::INT16: + return py::dtype("i2"); + case td_ns::typenum_t::UINT16: + return py::dtype("u2"); + case td_ns::typenum_t::INT32: + return py::dtype("i4"); + case td_ns::typenum_t::UINT32: + return py::dtype("u4"); + case td_ns::typenum_t::INT64: + return py::dtype("i8"); + case td_ns::typenum_t::UINT64: + return py::dtype("u8"); + case td_ns::typenum_t::HALF: + return py::dtype("f2"); + case td_ns::typenum_t::FLOAT: + return py::dtype("f4"); + case td_ns::typenum_t::DOUBLE: + return py::dtype("f8"); + case td_ns::typenum_t::CFLOAT: + return py::dtype("c8"); + case td_ns::typenum_t::CDOUBLE: + return py::dtype("c16"); + default: + throw py::value_error("Unrecognized dst_typeid"); + } +} + +int _result_typeid(int arg_typeid, const int *fn_output_id) +{ + if (arg_typeid < 0 || arg_typeid >= td_ns::num_types) { + throw py::value_error("Input typeid " + std::to_string(arg_typeid) + + " is outside of expected bounds."); + } + + return fn_output_id[arg_typeid]; +} + +namespace ew_cmn_ns = dpctl::tensor::kernels::elementwise_common; +using ew_cmn_ns::binary_contig_impl_fn_ptr_t; +using ew_cmn_ns::binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t; +using ew_cmn_ns::binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t; +using ew_cmn_ns::binary_strided_impl_fn_ptr_t; +using ew_cmn_ns::unary_contig_impl_fn_ptr_t; +using ew_cmn_ns::unary_strided_impl_fn_ptr_t; + +using ew_cmn_ns::binary_inplace_contig_impl_fn_ptr_t; +using ew_cmn_ns::binary_inplace_row_matrix_broadcast_impl_fn_ptr_t; +using ew_cmn_ns::binary_inplace_strided_impl_fn_ptr_t; + +// U01: ==== ABS (x) +namespace impl +{ + +namespace abs_fn_ns = dpctl::tensor::kernels::abs; + +static unary_contig_impl_fn_ptr_t abs_contig_dispatch_vector[td_ns::num_types]; +static int abs_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + abs_strided_dispatch_vector[td_ns::num_types]; + +void populate_abs_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = abs_fn_ns; + + using fn_ns::AbsContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(abs_contig_dispatch_vector); + + using fn_ns::AbsStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(abs_strided_dispatch_vector); + + using fn_ns::AbsTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(abs_output_typeid_vector); +}; + +} // namespace impl + +// U02: ==== ACOS (x) +namespace impl +{ + +namespace acos_fn_ns = dpctl::tensor::kernels::acos; + +static unary_contig_impl_fn_ptr_t acos_contig_dispatch_vector[td_ns::num_types]; +static int acos_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + acos_strided_dispatch_vector[td_ns::num_types]; + +void populate_acos_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = acos_fn_ns; + + using fn_ns::AcosContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(acos_contig_dispatch_vector); + + using fn_ns::AcosStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(acos_strided_dispatch_vector); + + using fn_ns::AcosTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(acos_output_typeid_vector); +} + +} // namespace impl + +// U03: ===== ACOSH (x) +namespace impl +{ + +namespace acosh_fn_ns = dpctl::tensor::kernels::acosh; + +static unary_contig_impl_fn_ptr_t + acosh_contig_dispatch_vector[td_ns::num_types]; +static int acosh_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + acosh_strided_dispatch_vector[td_ns::num_types]; + +void populate_acosh_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = acosh_fn_ns; + + using fn_ns::AcoshContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(acosh_contig_dispatch_vector); + + using fn_ns::AcoshStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(acosh_strided_dispatch_vector); + + using fn_ns::AcoshTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(acosh_output_typeid_vector); +} + +} // namespace impl + +// B01: ===== ADD (x1, x2) +namespace impl +{ +namespace add_fn_ns = dpctl::tensor::kernels::add; + +static binary_contig_impl_fn_ptr_t add_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static int add_output_id_table[td_ns::num_types][td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + add_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +// add(matrix, row) +static binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t + add_contig_matrix_contig_row_broadcast_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +// add(row, matrix) +static binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t + add_contig_row_contig_matrix_broadcast_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +static binary_inplace_contig_impl_fn_ptr_t + add_inplace_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static binary_inplace_strided_impl_fn_ptr_t + add_inplace_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; +static binary_inplace_row_matrix_broadcast_impl_fn_ptr_t + add_inplace_row_matrix_dispatch_table[td_ns::num_types][td_ns::num_types]; + +void populate_add_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = add_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::AddTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(add_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::AddStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(add_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::AddContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(add_contig_dispatch_table); + + // function pointers for operation on contiguous matrix, contiguous row + // with contiguous matrix output + using fn_ns::AddContigMatrixContigRowBroadcastFactory; + DispatchTableBuilder< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t, + AddContigMatrixContigRowBroadcastFactory, num_types> + dtb4; + dtb4.populate_dispatch_table( + add_contig_matrix_contig_row_broadcast_dispatch_table); + + // function pointers for operation on contiguous row, contiguous matrix + // with contiguous matrix output + using fn_ns::AddContigRowContigMatrixBroadcastFactory; + DispatchTableBuilder< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t, + AddContigRowContigMatrixBroadcastFactory, num_types> + dtb5; + dtb5.populate_dispatch_table( + add_contig_row_contig_matrix_broadcast_dispatch_table); + + // function pointers for inplace operation on general strided arrays + using fn_ns::AddInplaceStridedFactory; + DispatchTableBuilder + dtb6; + dtb6.populate_dispatch_table(add_inplace_strided_dispatch_table); + + // function pointers for inplace operation on contiguous inputs and output + using fn_ns::AddInplaceContigFactory; + DispatchTableBuilder + dtb7; + dtb7.populate_dispatch_table(add_inplace_contig_dispatch_table); + + // function pointers for inplace operation on contiguous matrix + // and contiguous row + using fn_ns::AddInplaceRowMatrixBroadcastFactory; + DispatchTableBuilder + dtb8; + dtb8.populate_dispatch_table(add_inplace_row_matrix_dispatch_table); +}; + +} // namespace impl + +// U04: ===== ASIN (x) +namespace impl +{ + +namespace asin_fn_ns = dpctl::tensor::kernels::asin; + +static unary_contig_impl_fn_ptr_t asin_contig_dispatch_vector[td_ns::num_types]; +static int asin_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + asin_strided_dispatch_vector[td_ns::num_types]; + +void populate_asin_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = asin_fn_ns; + + using fn_ns::AsinContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(asin_contig_dispatch_vector); + + using fn_ns::AsinStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(asin_strided_dispatch_vector); + + using fn_ns::AsinTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(asin_output_typeid_vector); +} + +} // namespace impl + +// U05: ===== ASINH (x) +namespace impl +{ + +namespace asinh_fn_ns = dpctl::tensor::kernels::asinh; + +static unary_contig_impl_fn_ptr_t + asinh_contig_dispatch_vector[td_ns::num_types]; +static int asinh_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + asinh_strided_dispatch_vector[td_ns::num_types]; + +void populate_asinh_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = asinh_fn_ns; + + using fn_ns::AsinhContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(asinh_contig_dispatch_vector); + + using fn_ns::AsinhStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(asinh_strided_dispatch_vector); + + using fn_ns::AsinhTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(asinh_output_typeid_vector); +} + +} // namespace impl + +// U06: ===== ATAN (x) +namespace impl +{ + +namespace atan_fn_ns = dpctl::tensor::kernels::atan; + +static unary_contig_impl_fn_ptr_t atan_contig_dispatch_vector[td_ns::num_types]; +static int atan_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + atan_strided_dispatch_vector[td_ns::num_types]; + +void populate_atan_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = atan_fn_ns; + + using fn_ns::AtanContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(atan_contig_dispatch_vector); + + using fn_ns::AtanStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(atan_strided_dispatch_vector); + + using fn_ns::AtanTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(atan_output_typeid_vector); +} + +} // namespace impl + +// B02: ===== ATAN2 (x1, x2) +namespace impl +{ +namespace atan2_fn_ns = dpctl::tensor::kernels::atan2; + +static binary_contig_impl_fn_ptr_t + atan2_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static int atan2_output_id_table[td_ns::num_types][td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + atan2_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +void populate_atan2_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = atan2_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::Atan2TypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(atan2_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::Atan2StridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(atan2_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::Atan2ContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(atan2_contig_dispatch_table); +}; + +} // namespace impl + +// U07: ===== ATANH (x) +namespace impl +{ + +namespace atanh_fn_ns = dpctl::tensor::kernels::atanh; + +static unary_contig_impl_fn_ptr_t + atanh_contig_dispatch_vector[td_ns::num_types]; +static int atanh_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + atanh_strided_dispatch_vector[td_ns::num_types]; + +void populate_atanh_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = atanh_fn_ns; + + using fn_ns::AtanhContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(atanh_contig_dispatch_vector); + + using fn_ns::AtanhStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(atanh_strided_dispatch_vector); + + using fn_ns::AtanhTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(atanh_output_typeid_vector); +} + +} // namespace impl + +// B03: ===== BITWISE_AND (x1, x2) +namespace impl +{ +namespace bitwise_and_fn_ns = dpctl::tensor::kernels::bitwise_and; + +static binary_contig_impl_fn_ptr_t + bitwise_and_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static int bitwise_and_output_id_table[td_ns::num_types][td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + bitwise_and_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +void populate_bitwise_and_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = bitwise_and_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::BitwiseAndTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(bitwise_and_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::BitwiseAndStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(bitwise_and_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::BitwiseAndContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(bitwise_and_contig_dispatch_table); +}; + +} // namespace impl + +// B04: ===== BITWISE_LEFT_SHIFT (x1, x2) +namespace impl +{ +namespace bitwise_left_shift_fn_ns = dpctl::tensor::kernels::bitwise_left_shift; + +static binary_contig_impl_fn_ptr_t + bitwise_left_shift_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static int bitwise_left_shift_output_id_table[td_ns::num_types] + [td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + bitwise_left_shift_strided_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_bitwise_left_shift_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = bitwise_left_shift_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::BitwiseLeftShiftTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(bitwise_left_shift_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::BitwiseLeftShiftStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(bitwise_left_shift_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::BitwiseLeftShiftContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(bitwise_left_shift_contig_dispatch_table); +}; + +} // namespace impl + +// U08: ===== BITWISE_INVERT (x) +namespace impl +{ + +namespace bitwise_invert_fn_ns = dpctl::tensor::kernels::bitwise_invert; + +static unary_contig_impl_fn_ptr_t + bitwise_invert_contig_dispatch_vector[td_ns::num_types]; +static int bitwise_invert_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + bitwise_invert_strided_dispatch_vector[td_ns::num_types]; + +void populate_bitwise_invert_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = bitwise_invert_fn_ns; + + using fn_ns::BitwiseInvertContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(bitwise_invert_contig_dispatch_vector); + + using fn_ns::BitwiseInvertStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(bitwise_invert_strided_dispatch_vector); + + using fn_ns::BitwiseInvertTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(bitwise_invert_output_typeid_vector); +}; + +} // namespace impl + +// B05: ===== BITWISE_OR (x1, x2) +namespace impl +{ +namespace bitwise_or_fn_ns = dpctl::tensor::kernels::bitwise_or; + +static binary_contig_impl_fn_ptr_t + bitwise_or_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static int bitwise_or_output_id_table[td_ns::num_types][td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + bitwise_or_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +void populate_bitwise_or_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = bitwise_or_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::BitwiseOrTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(bitwise_or_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::BitwiseOrStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(bitwise_or_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::BitwiseOrContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(bitwise_or_contig_dispatch_table); +}; +} // namespace impl + +// B06: ===== BITWISE_RIGHT_SHIFT (x1, x2) +namespace impl +{ +namespace bitwise_right_shift_fn_ns = + dpctl::tensor::kernels::bitwise_right_shift; + +static binary_contig_impl_fn_ptr_t + bitwise_right_shift_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static int bitwise_right_shift_output_id_table[td_ns::num_types] + [td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + bitwise_right_shift_strided_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_bitwise_right_shift_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = bitwise_right_shift_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::BitwiseRightShiftTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(bitwise_right_shift_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::BitwiseRightShiftStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(bitwise_right_shift_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::BitwiseRightShiftContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(bitwise_right_shift_contig_dispatch_table); +}; + +} // namespace impl + +// B07: ===== BITWISE_XOR (x1, x2) +namespace impl +{ +namespace bitwise_xor_fn_ns = dpctl::tensor::kernels::bitwise_xor; + +static binary_contig_impl_fn_ptr_t + bitwise_xor_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static int bitwise_xor_output_id_table[td_ns::num_types][td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + bitwise_xor_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +void populate_bitwise_xor_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = bitwise_xor_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::BitwiseXorTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(bitwise_xor_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::BitwiseXorStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(bitwise_xor_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::BitwiseXorContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(bitwise_xor_contig_dispatch_table); +}; +} // namespace impl + +// U09: ==== CEIL (x) +namespace impl +{ + +namespace ceil_fn_ns = dpctl::tensor::kernels::ceil; + +static unary_contig_impl_fn_ptr_t ceil_contig_dispatch_vector[td_ns::num_types]; +static int ceil_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + ceil_strided_dispatch_vector[td_ns::num_types]; + +void populate_ceil_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = ceil_fn_ns; + + using fn_ns::CeilContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(ceil_contig_dispatch_vector); + + using fn_ns::CeilStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(ceil_strided_dispatch_vector); + + using fn_ns::CeilTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(ceil_output_typeid_vector); +} + +} // namespace impl + +// U10: ==== CONJ (x) +namespace impl +{ + +namespace conj_fn_ns = dpctl::tensor::kernels::conj; + +static unary_contig_impl_fn_ptr_t conj_contig_dispatch_vector[td_ns::num_types]; +static int conj_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + conj_strided_dispatch_vector[td_ns::num_types]; + +void populate_conj_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = conj_fn_ns; + + using fn_ns::ConjContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(conj_contig_dispatch_vector); + + using fn_ns::ConjStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(conj_strided_dispatch_vector); + + using fn_ns::ConjTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(conj_output_typeid_vector); +} +} // namespace impl + +// U11: ==== COS (x) +namespace impl +{ + +namespace cos_fn_ns = dpctl::tensor::kernels::cos; + +static unary_contig_impl_fn_ptr_t cos_contig_dispatch_vector[td_ns::num_types]; +static int cos_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + cos_strided_dispatch_vector[td_ns::num_types]; + +void populate_cos_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = cos_fn_ns; + + using fn_ns::CosContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(cos_contig_dispatch_vector); + + using fn_ns::CosStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(cos_strided_dispatch_vector); + + using fn_ns::CosTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(cos_output_typeid_vector); +} + +} // namespace impl + +// U12: ==== COSH (x) +namespace impl +{ + +namespace cosh_fn_ns = dpctl::tensor::kernels::cosh; + +static unary_contig_impl_fn_ptr_t cosh_contig_dispatch_vector[td_ns::num_types]; +static int cosh_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + cosh_strided_dispatch_vector[td_ns::num_types]; + +void populate_cosh_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = cosh_fn_ns; + + using fn_ns::CoshContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(cosh_contig_dispatch_vector); + + using fn_ns::CoshStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(cosh_strided_dispatch_vector); + + using fn_ns::CoshTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(cosh_output_typeid_vector); +} + +} // namespace impl + +// B08: ==== DIVIDE (x1, x2) +namespace impl +{ +namespace true_divide_fn_ns = dpctl::tensor::kernels::true_divide; + +static binary_contig_impl_fn_ptr_t + true_divide_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static int true_divide_output_id_table[td_ns::num_types][td_ns::num_types]; +static int true_divide_inplace_output_id_table[td_ns::num_types] + [td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + true_divide_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +// divide(matrix, row) +static binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t + true_divide_contig_matrix_contig_row_broadcast_dispatch_table + [td_ns::num_types][td_ns::num_types]; + +// divide(row, matrix) +static binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t + true_divide_contig_row_contig_matrix_broadcast_dispatch_table + [td_ns::num_types][td_ns::num_types]; + +static binary_inplace_contig_impl_fn_ptr_t + true_divide_inplace_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static binary_inplace_strided_impl_fn_ptr_t + true_divide_inplace_strided_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static binary_inplace_row_matrix_broadcast_impl_fn_ptr_t + true_divide_inplace_row_matrix_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_true_divide_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = true_divide_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::TrueDivideTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(true_divide_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::TrueDivideStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(true_divide_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::TrueDivideContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(true_divide_contig_dispatch_table); + + // function pointers for operation on contiguous matrix, contiguous row + // with contiguous matrix output + using fn_ns::TrueDivideContigMatrixContigRowBroadcastFactory; + DispatchTableBuilder< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t, + TrueDivideContigMatrixContigRowBroadcastFactory, num_types> + dtb4; + dtb4.populate_dispatch_table( + true_divide_contig_matrix_contig_row_broadcast_dispatch_table); + + // function pointers for operation on contiguous row, contiguous matrix + // with contiguous matrix output + using fn_ns::TrueDivideContigRowContigMatrixBroadcastFactory; + DispatchTableBuilder< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t, + TrueDivideContigRowContigMatrixBroadcastFactory, num_types> + dtb5; + dtb5.populate_dispatch_table( + true_divide_contig_row_contig_matrix_broadcast_dispatch_table); + + // which input types are supported, and what is the type of the result + using fn_ns::TrueDivideInplaceTypeMapFactory; + DispatchTableBuilder dtb6; + dtb6.populate_dispatch_table(true_divide_inplace_output_id_table); + + // function pointers for inplace operation on general strided arrays + using fn_ns::TrueDivideInplaceStridedFactory; + DispatchTableBuilder + dtb7; + dtb7.populate_dispatch_table(true_divide_inplace_strided_dispatch_table); + + // function pointers for inplace operation on contiguous inputs and output + using fn_ns::TrueDivideInplaceContigFactory; + DispatchTableBuilder + dtb8; + dtb8.populate_dispatch_table(true_divide_inplace_contig_dispatch_table); + + // function pointers for inplace operation on contiguous matrix + // and contiguous row + using fn_ns::TrueDivideInplaceRowMatrixBroadcastFactory; + DispatchTableBuilder + dtb9; + dtb9.populate_dispatch_table(true_divide_inplace_row_matrix_dispatch_table); +}; + +} // namespace impl + +// B09: ==== EQUAL (x1, x2) +namespace impl +{ +namespace equal_fn_ns = dpctl::tensor::kernels::equal; + +static binary_contig_impl_fn_ptr_t + equal_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static int equal_output_id_table[td_ns::num_types][td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + equal_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +void populate_equal_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = equal_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::EqualTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(equal_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::EqualStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(equal_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::EqualContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(equal_contig_dispatch_table); +}; +} // namespace impl + +// U13: ==== EXP (x) +namespace impl +{ + +namespace exp_fn_ns = dpctl::tensor::kernels::exp; + +static unary_contig_impl_fn_ptr_t exp_contig_dispatch_vector[td_ns::num_types]; +static int exp_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + exp_strided_dispatch_vector[td_ns::num_types]; + +void populate_exp_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = exp_fn_ns; + + using fn_ns::ExpContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(exp_contig_dispatch_vector); + + using fn_ns::ExpStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(exp_strided_dispatch_vector); + + using fn_ns::ExpTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(exp_output_typeid_vector); +} + +} // namespace impl + +// U14: ==== EXPM1 (x) +namespace impl +{ + +namespace expm1_fn_ns = dpctl::tensor::kernels::expm1; + +static unary_contig_impl_fn_ptr_t + expm1_contig_dispatch_vector[td_ns::num_types]; +static int expm1_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + expm1_strided_dispatch_vector[td_ns::num_types]; + +void populate_expm1_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = expm1_fn_ns; + + using fn_ns::Expm1ContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(expm1_contig_dispatch_vector); + + using fn_ns::Expm1StridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(expm1_strided_dispatch_vector); + + using fn_ns::Expm1TypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(expm1_output_typeid_vector); +} + +} // namespace impl + +// U15: ==== FLOOR (x) +namespace impl +{ + +namespace floor_fn_ns = dpctl::tensor::kernels::floor; + +static unary_contig_impl_fn_ptr_t + floor_contig_dispatch_vector[td_ns::num_types]; +static int floor_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + floor_strided_dispatch_vector[td_ns::num_types]; + +void populate_floor_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = floor_fn_ns; + + using fn_ns::FloorContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(floor_contig_dispatch_vector); + + using fn_ns::FloorStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(floor_strided_dispatch_vector); + + using fn_ns::FloorTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(floor_output_typeid_vector); +} + +} // namespace impl + +// B10: ==== FLOOR_DIVIDE (x1, x2) +namespace impl +{ +namespace floor_divide_fn_ns = dpctl::tensor::kernels::floor_divide; + +static binary_contig_impl_fn_ptr_t + floor_divide_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static int floor_divide_output_id_table[td_ns::num_types][td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + floor_divide_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +static binary_inplace_contig_impl_fn_ptr_t + floor_divide_inplace_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static binary_inplace_strided_impl_fn_ptr_t + floor_divide_inplace_strided_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_floor_divide_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = floor_divide_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::FloorDivideTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(floor_divide_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::FloorDivideStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(floor_divide_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::FloorDivideContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(floor_divide_contig_dispatch_table); + + // function pointers for inplace operation on general strided arrays + using fn_ns::FloorDivideInplaceStridedFactory; + DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(floor_divide_inplace_strided_dispatch_table); + + // function pointers for inplace operation on contiguous inputs and output + using fn_ns::FloorDivideInplaceContigFactory; + DispatchTableBuilder + dtb5; + dtb5.populate_dispatch_table(floor_divide_inplace_contig_dispatch_table); +}; + +} // namespace impl + +// B11: ==== GREATER (x1, x2) +namespace impl +{ +namespace greater_fn_ns = dpctl::tensor::kernels::greater; + +static binary_contig_impl_fn_ptr_t + greater_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static int greater_output_id_table[td_ns::num_types][td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + greater_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +void populate_greater_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = greater_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::GreaterTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(greater_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::GreaterStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(greater_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::GreaterContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(greater_contig_dispatch_table); +}; +} // namespace impl + +// B12: ==== GREATER_EQUAL (x1, x2) +namespace impl +{ +namespace greater_equal_fn_ns = dpctl::tensor::kernels::greater_equal; + +static binary_contig_impl_fn_ptr_t + greater_equal_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static int greater_equal_output_id_table[td_ns::num_types][td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + greater_equal_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +void populate_greater_equal_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = greater_equal_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::GreaterEqualTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(greater_equal_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::GreaterEqualStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(greater_equal_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::GreaterEqualContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(greater_equal_contig_dispatch_table); +}; +} // namespace impl + +// U16: ==== IMAG (x) +namespace impl +{ + +namespace imag_fn_ns = dpctl::tensor::kernels::imag; + +static unary_contig_impl_fn_ptr_t imag_contig_dispatch_vector[td_ns::num_types]; +static int imag_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + imag_strided_dispatch_vector[td_ns::num_types]; + +void populate_imag_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = imag_fn_ns; + + using fn_ns::ImagContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(imag_contig_dispatch_vector); + + using fn_ns::ImagStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(imag_strided_dispatch_vector); + + using fn_ns::ImagTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(imag_output_typeid_vector); +} +} // namespace impl + +// U17: ==== ISFINITE (x) +namespace impl +{ +namespace isfinite_fn_ns = dpctl::tensor::kernels::isfinite; + +static unary_contig_impl_fn_ptr_t + isfinite_contig_dispatch_vector[td_ns::num_types]; +static int isfinite_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + isfinite_strided_dispatch_vector[td_ns::num_types]; + +void populate_isfinite_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = isfinite_fn_ns; + + using fn_ns::IsFiniteContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(isfinite_contig_dispatch_vector); + + using fn_ns::IsFiniteStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(isfinite_strided_dispatch_vector); + + using fn_ns::IsFiniteTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(isfinite_output_typeid_vector); +} + +} // namespace impl + +// U18: ==== ISINF (x) +namespace impl +{ +namespace isinf_fn_ns = dpctl::tensor::kernels::isinf; + +static unary_contig_impl_fn_ptr_t + isinf_contig_dispatch_vector[td_ns::num_types]; +static int isinf_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + isinf_strided_dispatch_vector[td_ns::num_types]; + +void populate_isinf_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = isinf_fn_ns; + + using fn_ns::IsInfContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(isinf_contig_dispatch_vector); + + using fn_ns::IsInfStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(isinf_strided_dispatch_vector); + + using fn_ns::IsInfTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(isinf_output_typeid_vector); +} + +} // namespace impl + +// U19: ==== ISNAN (x) +namespace impl +{ +namespace isnan_fn_ns = dpctl::tensor::kernels::isnan; + +static unary_contig_impl_fn_ptr_t + isnan_contig_dispatch_vector[td_ns::num_types]; +static int isnan_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + isnan_strided_dispatch_vector[td_ns::num_types]; + +void populate_isnan_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = isnan_fn_ns; + + using fn_ns::IsNanContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(isnan_contig_dispatch_vector); + + using fn_ns::IsNanStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(isnan_strided_dispatch_vector); + + using fn_ns::IsNanTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(isnan_output_typeid_vector); +} + +} // namespace impl + +// B13: ==== LESS (x1, x2) +namespace impl +{ +namespace less_fn_ns = dpctl::tensor::kernels::less; + +static binary_contig_impl_fn_ptr_t less_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static int less_output_id_table[td_ns::num_types][td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + less_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +void populate_less_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = less_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::LessTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(less_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::LessStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(less_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::LessContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(less_contig_dispatch_table); +}; +} // namespace impl + +// B14: ==== LESS_EQUAL (x1, x2) +namespace impl +{ +namespace less_equal_fn_ns = dpctl::tensor::kernels::less_equal; + +static binary_contig_impl_fn_ptr_t + less_equal_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static int less_equal_output_id_table[td_ns::num_types][td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + less_equal_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +void populate_less_equal_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = less_equal_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::LessEqualTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(less_equal_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::LessEqualStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(less_equal_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::LessEqualContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(less_equal_contig_dispatch_table); +}; +} // namespace impl + +// U20: ==== LOG (x) +namespace impl +{ + +namespace log_fn_ns = dpctl::tensor::kernels::log; + +static unary_contig_impl_fn_ptr_t log_contig_dispatch_vector[td_ns::num_types]; +static int log_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + log_strided_dispatch_vector[td_ns::num_types]; + +void populate_log_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = log_fn_ns; + + using fn_ns::LogContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(log_contig_dispatch_vector); + + using fn_ns::LogStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(log_strided_dispatch_vector); + + using fn_ns::LogTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(log_output_typeid_vector); +} + +} // namespace impl + +// U21: ==== LOG1P (x) +namespace impl +{ + +namespace log1p_fn_ns = dpctl::tensor::kernels::log1p; + +static unary_contig_impl_fn_ptr_t + log1p_contig_dispatch_vector[td_ns::num_types]; +static int log1p_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + log1p_strided_dispatch_vector[td_ns::num_types]; + +void populate_log1p_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = log1p_fn_ns; + + using fn_ns::Log1pContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(log1p_contig_dispatch_vector); + + using fn_ns::Log1pStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(log1p_strided_dispatch_vector); + + using fn_ns::Log1pTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(log1p_output_typeid_vector); +} + +} // namespace impl + +// U22: ==== LOG2 (x) +namespace impl +{ + +namespace log2_fn_ns = dpctl::tensor::kernels::log2; + +static unary_contig_impl_fn_ptr_t log2_contig_dispatch_vector[td_ns::num_types]; +static int log2_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + log2_strided_dispatch_vector[td_ns::num_types]; + +void populate_log2_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = log2_fn_ns; + + using fn_ns::Log2ContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(log2_contig_dispatch_vector); + + using fn_ns::Log2StridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(log2_strided_dispatch_vector); + + using fn_ns::Log2TypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(log2_output_typeid_vector); +}; + +} // namespace impl + +// U23: ==== LOG10 (x) +namespace impl +{ + +namespace log10_fn_ns = dpctl::tensor::kernels::log10; + +static unary_contig_impl_fn_ptr_t + log10_contig_dispatch_vector[td_ns::num_types]; +static int log10_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + log10_strided_dispatch_vector[td_ns::num_types]; + +void populate_log10_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = log10_fn_ns; + + using fn_ns::Log10ContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(log10_contig_dispatch_vector); + + using fn_ns::Log10StridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(log10_strided_dispatch_vector); + + using fn_ns::Log10TypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(log10_output_typeid_vector); +}; + +} // namespace impl + +// B15: ==== LOGADDEXP (x1, x2) +namespace impl +{ +namespace logaddexp_fn_ns = dpctl::tensor::kernels::logaddexp; + +static binary_contig_impl_fn_ptr_t + logaddexp_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static int logaddexp_output_id_table[td_ns::num_types][td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + logaddexp_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +void populate_logaddexp_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = logaddexp_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::LogAddExpTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(logaddexp_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::LogAddExpStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(logaddexp_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::LogAddExpContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(logaddexp_contig_dispatch_table); +}; +} // namespace impl + +// B16: ==== LOGICAL_AND (x1, x2) +namespace impl +{ +namespace logical_and_fn_ns = dpctl::tensor::kernels::logical_and; + +static binary_contig_impl_fn_ptr_t + logical_and_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static int logical_and_output_id_table[td_ns::num_types][td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + logical_and_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +void populate_logical_and_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = logical_and_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::LogicalAndTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(logical_and_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::LogicalAndStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(logical_and_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::LogicalAndContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(logical_and_contig_dispatch_table); +}; +} // namespace impl + +// U24: ==== LOGICAL_NOT (x) +namespace impl +{ +namespace logical_not_fn_ns = dpctl::tensor::kernels::logical_not; + +static unary_contig_impl_fn_ptr_t + logical_not_contig_dispatch_vector[td_ns::num_types]; +static int logical_not_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + logical_not_strided_dispatch_vector[td_ns::num_types]; + +void populate_logical_not_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = logical_not_fn_ns; + + using fn_ns::LogicalNotContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(logical_not_contig_dispatch_vector); + + using fn_ns::LogicalNotStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(logical_not_strided_dispatch_vector); + + using fn_ns::LogicalNotTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(logical_not_output_typeid_vector); +}; +} // namespace impl + +// B17: ==== LOGICAL_OR (x1, x2) +namespace impl +{ +namespace logical_or_fn_ns = dpctl::tensor::kernels::logical_or; + +static binary_contig_impl_fn_ptr_t + logical_or_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static int logical_or_output_id_table[td_ns::num_types][td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + logical_or_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +void populate_logical_or_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = logical_or_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::LogicalOrTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(logical_or_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::LogicalOrStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(logical_or_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::LogicalOrContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(logical_or_contig_dispatch_table); +}; +} // namespace impl + +// B18: ==== LOGICAL_XOR (x1, x2) +namespace impl +{ +namespace logical_xor_fn_ns = dpctl::tensor::kernels::logical_xor; + +static binary_contig_impl_fn_ptr_t + logical_xor_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static int logical_xor_output_id_table[td_ns::num_types][td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + logical_xor_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +void populate_logical_xor_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = logical_xor_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::LogicalXorTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(logical_xor_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::LogicalXorStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(logical_xor_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::LogicalXorContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(logical_xor_contig_dispatch_table); +}; +} // namespace impl + +// B??: ==== MAXIMUM (x1, x2) +namespace impl +{ + +namespace maximum_fn_ns = dpctl::tensor::kernels::maximum; + +static binary_contig_impl_fn_ptr_t + maximum_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static int maximum_output_id_table[td_ns::num_types][td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + maximum_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +void populate_maximum_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = maximum_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::MaximumTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(maximum_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::MaximumStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(maximum_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::MaximumContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(maximum_contig_dispatch_table); +}; + +} // namespace impl + +// B??: ==== MINIMUM (x1, x2) +namespace impl +{ + +namespace minimum_fn_ns = dpctl::tensor::kernels::minimum; + +static binary_contig_impl_fn_ptr_t + minimum_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static int minimum_output_id_table[td_ns::num_types][td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + minimum_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +void populate_minimum_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = minimum_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::MinimumTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(minimum_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::MinimumStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(minimum_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::MinimumContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(minimum_contig_dispatch_table); +}; + +} // namespace impl + +// B19: ==== MULTIPLY (x1, x2) +namespace impl +{ + +namespace multiply_fn_ns = dpctl::tensor::kernels::multiply; + +static binary_contig_impl_fn_ptr_t + multiply_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static int multiply_output_id_table[td_ns::num_types][td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + multiply_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +// mul(matrix, row) +static binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t + multiply_contig_matrix_contig_row_broadcast_dispatch_table + [td_ns::num_types][td_ns::num_types]; + +// mul(row, matrix) +static binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t + multiply_contig_row_contig_matrix_broadcast_dispatch_table + [td_ns::num_types][td_ns::num_types]; + +static binary_inplace_contig_impl_fn_ptr_t + multiply_inplace_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static binary_inplace_strided_impl_fn_ptr_t + multiply_inplace_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; +static binary_inplace_row_matrix_broadcast_impl_fn_ptr_t + multiply_inplace_row_matrix_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_multiply_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = multiply_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::MultiplyTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(multiply_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::MultiplyStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(multiply_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::MultiplyContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(multiply_contig_dispatch_table); + + // function pointers for operation on contiguous matrix, contiguous row + // with contiguous matrix output + using fn_ns::MultiplyContigMatrixContigRowBroadcastFactory; + DispatchTableBuilder< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t, + MultiplyContigMatrixContigRowBroadcastFactory, num_types> + dtb4; + dtb4.populate_dispatch_table( + multiply_contig_matrix_contig_row_broadcast_dispatch_table); + + // function pointers for operation on contiguous row, contiguous matrix + // with contiguous matrix output + using fn_ns::MultiplyContigRowContigMatrixBroadcastFactory; + DispatchTableBuilder< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t, + MultiplyContigRowContigMatrixBroadcastFactory, num_types> + dtb5; + dtb5.populate_dispatch_table( + multiply_contig_row_contig_matrix_broadcast_dispatch_table); + + // function pointers for inplace operation on general strided arrays + using fn_ns::MultiplyInplaceStridedFactory; + DispatchTableBuilder + dtb6; + dtb6.populate_dispatch_table(multiply_inplace_strided_dispatch_table); + + // function pointers for inplace operation on contiguous inputs and output + using fn_ns::MultiplyInplaceContigFactory; + DispatchTableBuilder + dtb7; + dtb7.populate_dispatch_table(multiply_inplace_contig_dispatch_table); + + // function pointers for inplace operation on contiguous matrix + // and contiguous row + using fn_ns::MultiplyInplaceRowMatrixBroadcastFactory; + DispatchTableBuilder + dtb8; + dtb8.populate_dispatch_table(multiply_inplace_row_matrix_dispatch_table); +}; + +} // namespace impl + +// U25: ==== NEGATIVE (x) +namespace impl +{ + +namespace negative_fn_ns = dpctl::tensor::kernels::negative; + +static unary_contig_impl_fn_ptr_t + negative_contig_dispatch_vector[td_ns::num_types]; +static int negative_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + negative_strided_dispatch_vector[td_ns::num_types]; + +void populate_negative_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = negative_fn_ns; + + using fn_ns::NegativeContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(negative_contig_dispatch_vector); + + using fn_ns::NegativeStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(negative_strided_dispatch_vector); + + using fn_ns::NegativeTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(negative_output_typeid_vector); +} + +} // namespace impl + +// B20: ==== NOT_EQUAL (x1, x2) +namespace impl +{ +namespace not_equal_fn_ns = dpctl::tensor::kernels::not_equal; + +static binary_contig_impl_fn_ptr_t + not_equal_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static int not_equal_output_id_table[td_ns::num_types][td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + not_equal_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +void populate_not_equal_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = not_equal_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::NotEqualTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(not_equal_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::NotEqualStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(not_equal_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::NotEqualContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(not_equal_contig_dispatch_table); +}; +} // namespace impl + +// U26: ==== POSITIVE (x) +namespace impl +{ + +namespace positive_fn_ns = dpctl::tensor::kernels::positive; + +static unary_contig_impl_fn_ptr_t + positive_contig_dispatch_vector[td_ns::num_types]; +static int positive_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + positive_strided_dispatch_vector[td_ns::num_types]; + +void populate_positive_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = positive_fn_ns; + + using fn_ns::PositiveContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(positive_contig_dispatch_vector); + + using fn_ns::PositiveStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(positive_strided_dispatch_vector); + + using fn_ns::PositiveTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(positive_output_typeid_vector); +} + +} // namespace impl + +// B21: ==== POW (x1, x2) +namespace impl +{ + +namespace pow_fn_ns = dpctl::tensor::kernels::pow; + +static binary_contig_impl_fn_ptr_t pow_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static int pow_output_id_table[td_ns::num_types][td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + pow_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +void populate_pow_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = pow_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::PowTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(pow_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::PowStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(pow_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::PowContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(pow_contig_dispatch_table); +}; + +} // namespace impl + +// U??: ==== PROJ (x) +namespace impl +{ + +namespace proj_fn_ns = dpctl::tensor::kernels::proj; + +static unary_contig_impl_fn_ptr_t proj_contig_dispatch_vector[td_ns::num_types]; +static int proj_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + proj_strided_dispatch_vector[td_ns::num_types]; + +void populate_proj_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = proj_fn_ns; + + using fn_ns::ProjContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(proj_contig_dispatch_vector); + + using fn_ns::ProjStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(proj_strided_dispatch_vector); + + using fn_ns::ProjTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(proj_output_typeid_vector); +} +} // namespace impl + +// U27: ==== REAL (x) +namespace impl +{ + +namespace real_fn_ns = dpctl::tensor::kernels::real; + +static unary_contig_impl_fn_ptr_t real_contig_dispatch_vector[td_ns::num_types]; +static int real_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + real_strided_dispatch_vector[td_ns::num_types]; + +void populate_real_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = real_fn_ns; + + using fn_ns::RealContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(real_contig_dispatch_vector); + + using fn_ns::RealStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(real_strided_dispatch_vector); + + using fn_ns::RealTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(real_output_typeid_vector); +} +} // namespace impl + +// B22: ==== REMAINDER (x1, x2) +namespace impl +{ + +namespace remainder_fn_ns = dpctl::tensor::kernels::remainder; + +static binary_contig_impl_fn_ptr_t + remainder_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static int remainder_output_id_table[td_ns::num_types][td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + remainder_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +void populate_remainder_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = remainder_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::RemainderTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(remainder_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::RemainderStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(remainder_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::RemainderContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(remainder_contig_dispatch_table); +} + +} // namespace impl + +// U28: ==== ROUND (x) +namespace impl +{ + +namespace round_fn_ns = dpctl::tensor::kernels::round; + +static unary_contig_impl_fn_ptr_t + round_contig_dispatch_vector[td_ns::num_types]; +static int round_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + round_strided_dispatch_vector[td_ns::num_types]; + +void populate_round_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = round_fn_ns; + + using fn_ns::RoundContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(round_contig_dispatch_vector); + + using fn_ns::RoundStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(round_strided_dispatch_vector); + + using fn_ns::RoundTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(round_output_typeid_vector); +} + +} // namespace impl + +// U29: ==== SIGN (x) +namespace impl +{ + +namespace sign_fn_ns = dpctl::tensor::kernels::sign; + +static unary_contig_impl_fn_ptr_t sign_contig_dispatch_vector[td_ns::num_types]; +static int sign_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + sign_strided_dispatch_vector[td_ns::num_types]; + +void populate_sign_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = sign_fn_ns; + + using fn_ns::SignContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(sign_contig_dispatch_vector); + + using fn_ns::SignStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(sign_strided_dispatch_vector); + + using fn_ns::SignTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(sign_output_typeid_vector); +} + +} // namespace impl + +// ==== SIGNBIT (x) +namespace impl +{ + +namespace signbit_fn_ns = dpctl::tensor::kernels::signbit; + +static unary_contig_impl_fn_ptr_t + signbit_contig_dispatch_vector[td_ns::num_types]; +static int signbit_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + signbit_strided_dispatch_vector[td_ns::num_types]; + +void populate_signbit_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = signbit_fn_ns; + + using fn_ns::SignbitContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(signbit_contig_dispatch_vector); + + using fn_ns::SignbitStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(signbit_strided_dispatch_vector); + + using fn_ns::SignbitTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(signbit_output_typeid_vector); +} + +} // namespace impl + +// U30: ==== SIN (x) +namespace impl +{ + +namespace sin_fn_ns = dpctl::tensor::kernels::sin; + +static unary_contig_impl_fn_ptr_t sin_contig_dispatch_vector[td_ns::num_types]; +static int sin_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + sin_strided_dispatch_vector[td_ns::num_types]; + +void populate_sin_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = sin_fn_ns; + + using fn_ns::SinContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(sin_contig_dispatch_vector); + + using fn_ns::SinStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(sin_strided_dispatch_vector); + + using fn_ns::SinTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(sin_output_typeid_vector); +} + +} // namespace impl + +// U31: ==== SINH (x) +namespace impl +{ + +namespace sinh_fn_ns = dpctl::tensor::kernels::sinh; + +static unary_contig_impl_fn_ptr_t sinh_contig_dispatch_vector[td_ns::num_types]; +static int sinh_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + sinh_strided_dispatch_vector[td_ns::num_types]; + +void populate_sinh_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = sinh_fn_ns; + + using fn_ns::SinhContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(sinh_contig_dispatch_vector); + + using fn_ns::SinhStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(sinh_strided_dispatch_vector); + + using fn_ns::SinhTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(sinh_output_typeid_vector); +} + +} // namespace impl + +// U32: ==== SQUARE (x) +namespace impl +{ + +namespace square_fn_ns = dpctl::tensor::kernels::square; + +static unary_contig_impl_fn_ptr_t + square_contig_dispatch_vector[td_ns::num_types]; +static int square_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + square_strided_dispatch_vector[td_ns::num_types]; + +void populate_square_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = square_fn_ns; + + using fn_ns::SquareContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(square_contig_dispatch_vector); + + using fn_ns::SquareStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(square_strided_dispatch_vector); + + using fn_ns::SquareTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(square_output_typeid_vector); +} + +} // namespace impl + +// U33: ==== SQRT (x) +namespace impl +{ + +namespace sqrt_fn_ns = dpctl::tensor::kernels::sqrt; + +static unary_contig_impl_fn_ptr_t sqrt_contig_dispatch_vector[td_ns::num_types]; +static int sqrt_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + sqrt_strided_dispatch_vector[td_ns::num_types]; + +void populate_sqrt_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = sqrt_fn_ns; + + using fn_ns::SqrtContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(sqrt_contig_dispatch_vector); + + using fn_ns::SqrtStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(sqrt_strided_dispatch_vector); + + using fn_ns::SqrtTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(sqrt_output_typeid_vector); +} + +} // namespace impl + +// B23: ==== SUBTRACT (x1, x2) +namespace impl +{ +namespace subtract_fn_ns = dpctl::tensor::kernels::subtract; + +static binary_contig_impl_fn_ptr_t + subtract_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static int subtract_output_id_table[td_ns::num_types][td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + subtract_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +// sub(matrix, row) +static binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t + subtract_contig_matrix_contig_row_broadcast_dispatch_table + [td_ns::num_types][td_ns::num_types]; + +// sub(row, matrix) +static binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t + subtract_contig_row_contig_matrix_broadcast_dispatch_table + [td_ns::num_types][td_ns::num_types]; + +static binary_inplace_contig_impl_fn_ptr_t + subtract_inplace_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static binary_inplace_strided_impl_fn_ptr_t + subtract_inplace_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; +static binary_inplace_row_matrix_broadcast_impl_fn_ptr_t + subtract_inplace_row_matrix_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_subtract_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = subtract_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::SubtractTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(subtract_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::SubtractStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(subtract_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::SubtractContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(subtract_contig_dispatch_table); + + // function pointers for operation on contiguous matrix, contiguous row + // with contiguous matrix output + using fn_ns::SubtractContigMatrixContigRowBroadcastFactory; + DispatchTableBuilder< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t, + SubtractContigMatrixContigRowBroadcastFactory, num_types> + dtb4; + dtb4.populate_dispatch_table( + subtract_contig_matrix_contig_row_broadcast_dispatch_table); + + // function pointers for operation on contiguous row, contiguous matrix + // with contiguous matrix output + using fn_ns::SubtractContigRowContigMatrixBroadcastFactory; + DispatchTableBuilder< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t, + SubtractContigRowContigMatrixBroadcastFactory, num_types> + dtb5; + dtb5.populate_dispatch_table( + subtract_contig_row_contig_matrix_broadcast_dispatch_table); + + // function pointers for inplace operation on general strided arrays + using fn_ns::SubtractInplaceStridedFactory; + DispatchTableBuilder + dtb6; + dtb6.populate_dispatch_table(subtract_inplace_strided_dispatch_table); + + // function pointers for inplace operation on contiguous inputs and output + using fn_ns::SubtractInplaceContigFactory; + DispatchTableBuilder + dtb7; + dtb7.populate_dispatch_table(subtract_inplace_contig_dispatch_table); + + // function pointers for inplace operation on contiguous matrix + // and contiguous row + using fn_ns::SubtractInplaceRowMatrixBroadcastFactory; + DispatchTableBuilder + dtb8; + dtb8.populate_dispatch_table(subtract_inplace_row_matrix_dispatch_table); +}; + +} // namespace impl + +// U34: ==== TAN (x) +namespace impl +{ + +namespace tan_fn_ns = dpctl::tensor::kernels::tan; + +static unary_contig_impl_fn_ptr_t tan_contig_dispatch_vector[td_ns::num_types]; +static int tan_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + tan_strided_dispatch_vector[td_ns::num_types]; + +void populate_tan_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = tan_fn_ns; + + using fn_ns::TanContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(tan_contig_dispatch_vector); + + using fn_ns::TanStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(tan_strided_dispatch_vector); + + using fn_ns::TanTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(tan_output_typeid_vector); +} + +} // namespace impl + +// U35: ==== TANH (x) +namespace impl +{ + +namespace tanh_fn_ns = dpctl::tensor::kernels::tanh; + +static unary_contig_impl_fn_ptr_t tanh_contig_dispatch_vector[td_ns::num_types]; +static int tanh_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + tanh_strided_dispatch_vector[td_ns::num_types]; + +void populate_tanh_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = tanh_fn_ns; + + using fn_ns::TanhContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(tanh_contig_dispatch_vector); + + using fn_ns::TanhStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(tanh_strided_dispatch_vector); + + using fn_ns::TanhTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(tanh_output_typeid_vector); +} + +} // namespace impl + +// U36: ==== TRUNC (x) +namespace impl +{ + +namespace trunc_fn_ns = dpctl::tensor::kernels::trunc; + +static unary_contig_impl_fn_ptr_t + trunc_contig_dispatch_vector[td_ns::num_types]; +static int trunc_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + trunc_strided_dispatch_vector[td_ns::num_types]; + +void populate_trunc_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = trunc_fn_ns; + + using fn_ns::TruncContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(trunc_contig_dispatch_vector); + + using fn_ns::TruncStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(trunc_strided_dispatch_vector); + + using fn_ns::TruncTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(trunc_output_typeid_vector); +} + +} // namespace impl + +// B24: ==== HYPOT (x1, x2) +namespace impl +{ +namespace hypot_fn_ns = dpctl::tensor::kernels::hypot; + +static binary_contig_impl_fn_ptr_t + hypot_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static int hypot_output_id_table[td_ns::num_types][td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + hypot_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +void populate_hypot_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = hypot_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::HypotTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(hypot_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::HypotStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(hypot_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::HypotContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(hypot_contig_dispatch_table); +}; + +} // namespace impl + +// U37: ==== CBRT (x) +namespace impl +{ + +namespace cbrt_fn_ns = dpctl::tensor::kernels::cbrt; + +static unary_contig_impl_fn_ptr_t cbrt_contig_dispatch_vector[td_ns::num_types]; +static int cbrt_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + cbrt_strided_dispatch_vector[td_ns::num_types]; + +void populate_cbrt_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = cbrt_fn_ns; + + using fn_ns::CbrtContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(cbrt_contig_dispatch_vector); + + using fn_ns::CbrtStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(cbrt_strided_dispatch_vector); + + using fn_ns::CbrtTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(cbrt_output_typeid_vector); +} + +} // namespace impl + +// B24: ==== COPYSIGN (x1, x2) +namespace impl +{ +namespace copysign_fn_ns = dpctl::tensor::kernels::copysign; + +static binary_contig_impl_fn_ptr_t + copysign_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static int copysign_output_id_table[td_ns::num_types][td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + copysign_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +void populate_copysign_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = copysign_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::CopysignTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(copysign_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::CopysignStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(copysign_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::CopysignContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(copysign_contig_dispatch_table); +}; + +} // namespace impl + +// U38: ==== EXP2 (x) +namespace impl +{ + +namespace exp2_fn_ns = dpctl::tensor::kernels::exp2; + +static unary_contig_impl_fn_ptr_t exp2_contig_dispatch_vector[td_ns::num_types]; +static int exp2_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + exp2_strided_dispatch_vector[td_ns::num_types]; + +void populate_exp2_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = exp2_fn_ns; + + using fn_ns::Exp2ContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(exp2_contig_dispatch_vector); + + using fn_ns::Exp2StridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(exp2_strided_dispatch_vector); + + using fn_ns::Exp2TypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(exp2_output_typeid_vector); +} + +} // namespace impl + +// U39: ==== RSQRT (x) +namespace impl +{ + +namespace rsqrt_fn_ns = dpctl::tensor::kernels::rsqrt; + +static unary_contig_impl_fn_ptr_t + rsqrt_contig_dispatch_vector[td_ns::num_types]; +static int rsqrt_output_typeid_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + rsqrt_strided_dispatch_vector[td_ns::num_types]; + +void populate_rsqrt_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = rsqrt_fn_ns; + + using fn_ns::RsqrtContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(rsqrt_contig_dispatch_vector); + + using fn_ns::RsqrtStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(rsqrt_strided_dispatch_vector); + + using fn_ns::RsqrtTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(rsqrt_output_typeid_vector); +} + +} // namespace impl + +// ========================================================================================== +// // + +namespace py = pybind11; + +void init_elementwise_functions(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + + // U01: ==== ABS (x) + { + impl::populate_abs_dispatch_vectors(); + using impl::abs_contig_dispatch_vector; + using impl::abs_output_typeid_vector; + using impl::abs_strided_dispatch_vector; + + auto abs_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, abs_output_typeid_vector, + abs_contig_dispatch_vector, abs_strided_dispatch_vector); + }; + m.def("_abs", abs_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto abs_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, abs_output_typeid_vector); + }; + m.def("_abs_result_type", abs_result_type_pyapi); + } + + // U02: ==== ACOS (x) + { + impl::populate_acos_dispatch_vectors(); + using impl::acos_contig_dispatch_vector; + using impl::acos_output_typeid_vector; + using impl::acos_strided_dispatch_vector; + + auto acos_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, acos_output_typeid_vector, + acos_contig_dispatch_vector, acos_strided_dispatch_vector); + }; + m.def("_acos", acos_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto acos_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, acos_output_typeid_vector); + }; + m.def("_acos_result_type", acos_result_type_pyapi); + } + + // U03: ===== ACOSH (x) + { + impl::populate_acosh_dispatch_vectors(); + using impl::acosh_contig_dispatch_vector; + using impl::acosh_output_typeid_vector; + using impl::acosh_strided_dispatch_vector; + + auto acosh_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, acosh_output_typeid_vector, + acosh_contig_dispatch_vector, acosh_strided_dispatch_vector); + }; + m.def("_acosh", acosh_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto acosh_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, + acosh_output_typeid_vector); + }; + m.def("_acosh_result_type", acosh_result_type_pyapi); + } + + // B01: ===== ADD (x1, x2) + { + impl::populate_add_dispatch_tables(); + using impl::add_contig_dispatch_table; + using impl::add_contig_matrix_contig_row_broadcast_dispatch_table; + using impl::add_contig_row_contig_matrix_broadcast_dispatch_table; + using impl::add_output_id_table; + using impl::add_strided_dispatch_table; + + auto add_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, + const dpctl::tensor::usm_ndarray &src2, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends = {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, add_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + add_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + add_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + add_contig_matrix_contig_row_broadcast_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + add_contig_row_contig_matrix_broadcast_dispatch_table); + }; + auto add_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + add_output_id_table); + }; + m.def("_add", add_pyapi, "", py::arg("src1"), py::arg("src2"), + py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_add_result_type", add_result_type_pyapi, ""); + + using impl::add_inplace_contig_dispatch_table; + using impl::add_inplace_row_matrix_dispatch_table; + using impl::add_inplace_strided_dispatch_table; + + auto add_inplace_pyapi = + [&](const dpctl::tensor::usm_ndarray &src, + const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, + const std::vector &depends = {}) { + return py_binary_inplace_ufunc( + src, dst, exec_q, depends, add_output_id_table, + // function pointers to handle inplace operation on + // contiguous arrays (pointers may be nullptr) + add_inplace_contig_dispatch_table, + // function pointers to handle inplace operation on strided + // arrays (most general case) + add_inplace_strided_dispatch_table, + // function pointers to handle inplace operation on + // c-contig matrix with c-contig row with broadcasting + // (may be nullptr) + add_inplace_row_matrix_dispatch_table); + }; + m.def("_add_inplace", add_inplace_pyapi, "", py::arg("lhs"), + py::arg("rhs"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + } + + // U04: ===== ASIN (x) + { + impl::populate_asin_dispatch_vectors(); + using impl::asin_contig_dispatch_vector; + using impl::asin_output_typeid_vector; + using impl::asin_strided_dispatch_vector; + + auto asin_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, asin_output_typeid_vector, + asin_contig_dispatch_vector, asin_strided_dispatch_vector); + }; + m.def("_asin", asin_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto asin_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, asin_output_typeid_vector); + }; + m.def("_asin_result_type", asin_result_type_pyapi); + } + + // U05: ===== ASINH (x) + { + impl::populate_asinh_dispatch_vectors(); + using impl::asinh_contig_dispatch_vector; + using impl::asinh_output_typeid_vector; + using impl::asinh_strided_dispatch_vector; + + auto asinh_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, asinh_output_typeid_vector, + asinh_contig_dispatch_vector, asinh_strided_dispatch_vector); + }; + m.def("_asinh", asinh_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto asinh_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, + asinh_output_typeid_vector); + }; + m.def("_asinh_result_type", asinh_result_type_pyapi); + } + + // U06: ===== ATAN (x) + { + impl::populate_atan_dispatch_vectors(); + using impl::atan_contig_dispatch_vector; + using impl::atan_output_typeid_vector; + using impl::atan_strided_dispatch_vector; + + auto atan_pyapi = [&](arrayT src, arrayT dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, atan_output_typeid_vector, + atan_contig_dispatch_vector, atan_strided_dispatch_vector); + }; + m.def("_atan", atan_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto atan_result_type_pyapi = [&](py::dtype dtype) { + return py_unary_ufunc_result_type(dtype, atan_output_typeid_vector); + }; + m.def("_atan_result_type", atan_result_type_pyapi); + } + + // B02: ===== ATAN2 (x1, x2) + { + impl::populate_atan2_dispatch_tables(); + using impl::atan2_contig_dispatch_table; + using impl::atan2_output_id_table; + using impl::atan2_strided_dispatch_table; + + auto atan2_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, + const dpctl::tensor::usm_ndarray &src2, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends = {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, atan2_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + atan2_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + atan2_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto atan2_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + atan2_output_id_table); + }; + m.def("_atan2", atan2_pyapi, "", py::arg("src1"), py::arg("src2"), + py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_atan2_result_type", atan2_result_type_pyapi, ""); + } + + // U07: ===== ATANH (x) + { + impl::populate_atanh_dispatch_vectors(); + using impl::atanh_contig_dispatch_vector; + using impl::atanh_output_typeid_vector; + using impl::atanh_strided_dispatch_vector; + + auto atanh_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, atanh_output_typeid_vector, + atanh_contig_dispatch_vector, atanh_strided_dispatch_vector); + }; + m.def("_atanh", atanh_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto atanh_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, + atanh_output_typeid_vector); + }; + m.def("_atanh_result_type", atanh_result_type_pyapi); + } + + // B03: ===== BITWISE_AND (x1, x2) + { + impl::populate_bitwise_and_dispatch_tables(); + using impl::bitwise_and_contig_dispatch_table; + using impl::bitwise_and_output_id_table; + using impl::bitwise_and_strided_dispatch_table; + + auto bitwise_and_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, + const dpctl::tensor::usm_ndarray &src2, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends = + {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, bitwise_and_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + bitwise_and_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + bitwise_and_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto bitwise_and_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + bitwise_and_output_id_table); + }; + m.def("_bitwise_and", bitwise_and_pyapi, "", py::arg("src1"), + py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_bitwise_and_result_type", bitwise_and_result_type_pyapi, ""); + } + + // B04: ===== BITWISE_LEFT_SHIFT (x1, x2) + { + impl::populate_bitwise_left_shift_dispatch_tables(); + using impl::bitwise_left_shift_contig_dispatch_table; + using impl::bitwise_left_shift_output_id_table; + using impl::bitwise_left_shift_strided_dispatch_table; + + auto bitwise_left_shift_pyapi = [&](const dpctl::tensor::usm_ndarray + &src1, + const dpctl::tensor::usm_ndarray + &src2, + const dpctl::tensor::usm_ndarray + &dst, + sycl::queue &exec_q, + const std::vector + &depends = {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, + bitwise_left_shift_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + bitwise_left_shift_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + bitwise_left_shift_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto bitwise_left_shift_result_type_pyapi = + [&](const py::dtype &dtype1, const py::dtype &dtype2) { + return py_binary_ufunc_result_type( + dtype1, dtype2, bitwise_left_shift_output_id_table); + }; + m.def("_bitwise_left_shift", bitwise_left_shift_pyapi, "", + py::arg("src1"), py::arg("src2"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + m.def("_bitwise_left_shift_result_type", + bitwise_left_shift_result_type_pyapi, ""); + } + + // U08: ===== BITWISE_INVERT (x) + { + impl::populate_bitwise_invert_dispatch_vectors(); + using impl::bitwise_invert_contig_dispatch_vector; + using impl::bitwise_invert_output_typeid_vector; + using impl::bitwise_invert_strided_dispatch_vector; + + auto bitwise_invert_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc(src, dst, exec_q, depends, + bitwise_invert_output_typeid_vector, + bitwise_invert_contig_dispatch_vector, + bitwise_invert_strided_dispatch_vector); + }; + m.def("_bitwise_invert", bitwise_invert_pyapi, "", py::arg("src"), + py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + + auto bitwise_invert_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type( + dtype, bitwise_invert_output_typeid_vector); + }; + m.def("_bitwise_invert_result_type", bitwise_invert_result_type_pyapi); + } + + // B05: ===== BITWISE_OR (x1, x2) + { + impl::populate_bitwise_or_dispatch_tables(); + using impl::bitwise_or_contig_dispatch_table; + using impl::bitwise_or_output_id_table; + using impl::bitwise_or_strided_dispatch_table; + + auto bitwise_or_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, + const dpctl::tensor::usm_ndarray &src2, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends = + {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, bitwise_or_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + bitwise_or_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + bitwise_or_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto bitwise_or_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + bitwise_or_output_id_table); + }; + m.def("_bitwise_or", bitwise_or_pyapi, "", py::arg("src1"), + py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_bitwise_or_result_type", bitwise_or_result_type_pyapi, ""); + } + + // B06: ===== BITWISE_RIGHT_SHIFT (x1, x2) + { + impl::populate_bitwise_right_shift_dispatch_tables(); + using impl::bitwise_right_shift_contig_dispatch_table; + using impl::bitwise_right_shift_output_id_table; + using impl::bitwise_right_shift_strided_dispatch_table; + + auto bitwise_right_shift_pyapi = [&](const dpctl::tensor::usm_ndarray + &src1, + const dpctl::tensor::usm_ndarray + &src2, + const dpctl::tensor::usm_ndarray + &dst, + sycl::queue &exec_q, + const std::vector + &depends = {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, + bitwise_right_shift_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + bitwise_right_shift_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + bitwise_right_shift_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto bitwise_right_shift_result_type_pyapi = + [&](const py::dtype &dtype1, const py::dtype &dtype2) { + return py_binary_ufunc_result_type( + dtype1, dtype2, bitwise_right_shift_output_id_table); + }; + m.def("_bitwise_right_shift", bitwise_right_shift_pyapi, "", + py::arg("src1"), py::arg("src2"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + m.def("_bitwise_right_shift_result_type", + bitwise_right_shift_result_type_pyapi, ""); + } + + // B07: ===== BITWISE_XOR (x1, x2) + { + impl::populate_bitwise_xor_dispatch_tables(); + using impl::bitwise_xor_contig_dispatch_table; + using impl::bitwise_xor_output_id_table; + using impl::bitwise_xor_strided_dispatch_table; + + auto bitwise_xor_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, + const dpctl::tensor::usm_ndarray &src2, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends = + {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, bitwise_xor_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + bitwise_xor_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + bitwise_xor_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto bitwise_xor_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + bitwise_xor_output_id_table); + }; + m.def("_bitwise_xor", bitwise_xor_pyapi, "", py::arg("src1"), + py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_bitwise_xor_result_type", bitwise_xor_result_type_pyapi, ""); + } + + // U09: ==== CEIL (x) + { + impl::populate_ceil_dispatch_vectors(); + using impl::ceil_contig_dispatch_vector; + using impl::ceil_output_typeid_vector; + using impl::ceil_strided_dispatch_vector; + + auto ceil_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, ceil_output_typeid_vector, + ceil_contig_dispatch_vector, ceil_strided_dispatch_vector); + }; + m.def("_ceil", ceil_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto ceil_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, ceil_output_typeid_vector); + }; + m.def("_ceil_result_type", ceil_result_type_pyapi); + } + + // U10: ==== CONJ (x) + { + impl::populate_conj_dispatch_vectors(); + using impl::conj_contig_dispatch_vector; + using impl::conj_output_typeid_vector; + using impl::conj_strided_dispatch_vector; + + auto conj_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, conj_output_typeid_vector, + conj_contig_dispatch_vector, conj_strided_dispatch_vector); + }; + m.def("_conj", conj_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto conj_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, conj_output_typeid_vector); + }; + m.def("_conj_result_type", conj_result_type_pyapi); + } + + // U11: ==== COS (x) + { + impl::populate_cos_dispatch_vectors(); + using impl::cos_contig_dispatch_vector; + using impl::cos_output_typeid_vector; + using impl::cos_strided_dispatch_vector; + + auto cos_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, cos_output_typeid_vector, + cos_contig_dispatch_vector, cos_strided_dispatch_vector); + }; + m.def("_cos", cos_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto cos_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, cos_output_typeid_vector); + }; + m.def("_cos_result_type", cos_result_type_pyapi); + } + + // U12: ==== COSH (x) + { + impl::populate_cosh_dispatch_vectors(); + using impl::cosh_contig_dispatch_vector; + using impl::cosh_output_typeid_vector; + using impl::cosh_strided_dispatch_vector; + + auto cosh_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, cosh_output_typeid_vector, + cosh_contig_dispatch_vector, cosh_strided_dispatch_vector); + }; + m.def("_cosh", cosh_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto cosh_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, cosh_output_typeid_vector); + }; + m.def("_cosh_result_type", cosh_result_type_pyapi); + } + + // B08: ==== DIVIDE (x1, x2) + { + impl::populate_true_divide_dispatch_tables(); + using impl::true_divide_contig_dispatch_table; + using impl:: + true_divide_contig_matrix_contig_row_broadcast_dispatch_table; + using impl:: + true_divide_contig_row_contig_matrix_broadcast_dispatch_table; + using impl::true_divide_output_id_table; + using impl::true_divide_strided_dispatch_table; + + auto divide_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, + const dpctl::tensor::usm_ndarray &src2, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends = {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, true_divide_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + true_divide_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + true_divide_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + true_divide_contig_matrix_contig_row_broadcast_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + true_divide_contig_row_contig_matrix_broadcast_dispatch_table); + }; + auto divide_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + true_divide_output_id_table); + }; + m.def("_divide", divide_pyapi, "", py::arg("src1"), py::arg("src2"), + py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_divide_result_type", divide_result_type_pyapi, ""); + + using impl::true_divide_inplace_contig_dispatch_table; + using impl::true_divide_inplace_output_id_table; + using impl::true_divide_inplace_row_matrix_dispatch_table; + using impl::true_divide_inplace_strided_dispatch_table; + + auto divide_inplace_pyapi = + [&](const dpctl::tensor::usm_ndarray &src, + const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, + const std::vector &depends = {}) { + return py_binary_inplace_ufunc( + src, dst, exec_q, depends, + true_divide_inplace_output_id_table, + // function pointers to handle inplace operation on + // contiguous arrays (pointers may be nullptr) + true_divide_inplace_contig_dispatch_table, + // function pointers to handle inplace operation on strided + // arrays (most general case) + true_divide_inplace_strided_dispatch_table, + // function pointers to handle inplace operation on + // c-contig matrix with c-contig row with broadcasting + // (may be nullptr) + true_divide_inplace_row_matrix_dispatch_table); + }; + m.def("_divide_inplace", divide_inplace_pyapi, "", py::arg("lhs"), + py::arg("rhs"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + } + + // B09: ==== EQUAL (x1, x2) + { + impl::populate_equal_dispatch_tables(); + using impl::equal_contig_dispatch_table; + using impl::equal_output_id_table; + using impl::equal_strided_dispatch_table; + + auto equal_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, + const dpctl::tensor::usm_ndarray &src2, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends = {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, equal_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + equal_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + equal_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto equal_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + equal_output_id_table); + }; + m.def("_equal", equal_pyapi, "", py::arg("src1"), py::arg("src2"), + py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_equal_result_type", equal_result_type_pyapi, ""); + } + + // U13: ==== EXP (x) + { + impl::populate_exp_dispatch_vectors(); + using impl::exp_contig_dispatch_vector; + using impl::exp_output_typeid_vector; + using impl::exp_strided_dispatch_vector; + + auto exp_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, exp_output_typeid_vector, + exp_contig_dispatch_vector, exp_strided_dispatch_vector); + }; + m.def("_exp", exp_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto exp_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, exp_output_typeid_vector); + }; + m.def("_exp_result_type", exp_result_type_pyapi); + } + + // U14: ==== EXPM1 (x) + { + impl::populate_expm1_dispatch_vectors(); + using impl::expm1_contig_dispatch_vector; + using impl::expm1_output_typeid_vector; + using impl::expm1_strided_dispatch_vector; + + auto expm1_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, expm1_output_typeid_vector, + expm1_contig_dispatch_vector, expm1_strided_dispatch_vector); + }; + m.def("_expm1", expm1_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto expm1_result_type_pyapi = [&](const py::dtype dtype) { + return py_unary_ufunc_result_type(dtype, + expm1_output_typeid_vector); + }; + m.def("_expm1_result_type", expm1_result_type_pyapi); + } + + // U15: ==== FLOOR (x) + { + impl::populate_floor_dispatch_vectors(); + using impl::floor_contig_dispatch_vector; + using impl::floor_output_typeid_vector; + using impl::floor_strided_dispatch_vector; + + auto floor_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, floor_output_typeid_vector, + floor_contig_dispatch_vector, floor_strided_dispatch_vector); + }; + m.def("_floor", floor_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto floor_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, + floor_output_typeid_vector); + }; + m.def("_floor_result_type", floor_result_type_pyapi); + } + + // B10: ==== FLOOR_DIVIDE (x1, x2) + { + impl::populate_floor_divide_dispatch_tables(); + using impl::floor_divide_contig_dispatch_table; + using impl::floor_divide_output_id_table; + using impl::floor_divide_strided_dispatch_table; + + auto floor_divide_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, + const dpctl::tensor::usm_ndarray &src2, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends = + {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, floor_divide_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + floor_divide_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + floor_divide_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto floor_divide_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + floor_divide_output_id_table); + }; + m.def("_floor_divide", floor_divide_pyapi, "", py::arg("src1"), + py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_floor_divide_result_type", floor_divide_result_type_pyapi, ""); + + using impl::floor_divide_inplace_contig_dispatch_table; + using impl::floor_divide_inplace_strided_dispatch_table; + + auto floor_divide_inplace_pyapi = + [&](const dpctl::tensor::usm_ndarray &src, + const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, + const std::vector &depends = {}) { + return py_binary_inplace_ufunc( + src, dst, exec_q, depends, floor_divide_output_id_table, + // function pointers to handle inplace operation on + // contiguous arrays (pointers may be nullptr) + floor_divide_inplace_contig_dispatch_table, + // function pointers to handle inplace operation on strided + // arrays (most general case) + floor_divide_inplace_strided_dispatch_table, + // function pointers to handle inplace operation on + // c-contig matrix with c-contig row with broadcasting + // (may be nullptr) + td_ns::NullPtrTable< + binary_inplace_row_matrix_broadcast_impl_fn_ptr_t>{}); + }; + m.def("_floor_divide_inplace", floor_divide_inplace_pyapi, "", + py::arg("lhs"), py::arg("rhs"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + } + + // B11: ==== GREATER (x1, x2) + { + impl::populate_greater_dispatch_tables(); + using impl::greater_contig_dispatch_table; + using impl::greater_output_id_table; + using impl::greater_strided_dispatch_table; + + auto greater_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, + const dpctl::tensor::usm_ndarray &src2, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends = {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, greater_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + greater_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + greater_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto greater_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + greater_output_id_table); + }; + m.def("_greater", greater_pyapi, "", py::arg("src1"), py::arg("src2"), + py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_greater_result_type", greater_result_type_pyapi, ""); + } + + // B12: ==== GREATER_EQUAL (x1, x2) + { + impl::populate_greater_equal_dispatch_tables(); + using impl::greater_equal_contig_dispatch_table; + using impl::greater_equal_output_id_table; + using impl::greater_equal_strided_dispatch_table; + + auto greater_equal_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, + const dpctl::tensor::usm_ndarray &src2, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends = + {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, greater_equal_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + greater_equal_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + greater_equal_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto greater_equal_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + greater_equal_output_id_table); + }; + m.def("_greater_equal", greater_equal_pyapi, "", py::arg("src1"), + py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_greater_equal_result_type", greater_equal_result_type_pyapi, + ""); + } + + // U16: ==== IMAG (x) + { + impl::populate_imag_dispatch_vectors(); + using impl::imag_contig_dispatch_vector; + using impl::imag_output_typeid_vector; + using impl::imag_strided_dispatch_vector; + + auto imag_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, imag_output_typeid_vector, + imag_contig_dispatch_vector, imag_strided_dispatch_vector); + }; + m.def("_imag", imag_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto imag_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, imag_output_typeid_vector); + }; + m.def("_imag_result_type", imag_result_type_pyapi); + } + + // U17: ==== ISFINITE (x) + { + impl::populate_isfinite_dispatch_vectors(); + + using impl::isfinite_contig_dispatch_vector; + using impl::isfinite_output_typeid_vector; + using impl::isfinite_strided_dispatch_vector; + auto isfinite_pyapi = + [&](const dpctl::tensor::usm_ndarray &src, + const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, + const std::vector &depends = {}) { + return py_unary_ufunc(src, dst, exec_q, depends, + isfinite_output_typeid_vector, + isfinite_contig_dispatch_vector, + isfinite_strided_dispatch_vector); + }; + auto isfinite_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, + isfinite_output_typeid_vector); + }; + m.def("_isfinite", isfinite_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + m.def("_isfinite_result_type", isfinite_result_type_pyapi, ""); + } + + // U18: ==== ISINF (x) + { + impl::populate_isinf_dispatch_vectors(); + + using impl::isinf_contig_dispatch_vector; + using impl::isinf_output_typeid_vector; + using impl::isinf_strided_dispatch_vector; + auto isinf_pyapi = [&](const dpctl::tensor::usm_ndarray &src, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, isinf_output_typeid_vector, + isinf_contig_dispatch_vector, isinf_strided_dispatch_vector); + }; + auto isinf_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, + isinf_output_typeid_vector); + }; + m.def("_isinf", isinf_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + m.def("_isinf_result_type", isinf_result_type_pyapi, ""); + } + + // U19: ==== ISNAN (x) + { + impl::populate_isnan_dispatch_vectors(); + + using impl::isnan_contig_dispatch_vector; + using impl::isnan_output_typeid_vector; + using impl::isnan_strided_dispatch_vector; + auto isnan_pyapi = [&](const dpctl::tensor::usm_ndarray &src, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, isnan_output_typeid_vector, + isnan_contig_dispatch_vector, isnan_strided_dispatch_vector); + }; + auto isnan_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, + isnan_output_typeid_vector); + }; + m.def("_isnan", isnan_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + m.def("_isnan_result_type", isnan_result_type_pyapi, ""); + } + + // B13: ==== LESS (x1, x2) + { + impl::populate_less_dispatch_tables(); + using impl::less_contig_dispatch_table; + using impl::less_output_id_table; + using impl::less_strided_dispatch_table; + + auto less_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, + const dpctl::tensor::usm_ndarray &src2, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends = {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, less_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + less_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + less_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto less_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + less_output_id_table); + }; + m.def("_less", less_pyapi, "", py::arg("src1"), py::arg("src2"), + py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_less_result_type", less_result_type_pyapi, ""); + } + + // B14: ==== LESS_EQUAL (x1, x2) + { + impl::populate_less_equal_dispatch_tables(); + using impl::less_equal_contig_dispatch_table; + using impl::less_equal_output_id_table; + using impl::less_equal_strided_dispatch_table; + + auto less_equal_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, + const dpctl::tensor::usm_ndarray &src2, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends = + {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, less_equal_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + less_equal_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + less_equal_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto less_equal_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + less_equal_output_id_table); + }; + m.def("_less_equal", less_equal_pyapi, "", py::arg("src1"), + py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_less_equal_result_type", less_equal_result_type_pyapi, ""); + } + + // U20: ==== LOG (x) + { + impl::populate_log_dispatch_vectors(); + using impl::log_contig_dispatch_vector; + using impl::log_output_typeid_vector; + using impl::log_strided_dispatch_vector; + + auto log_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, log_output_typeid_vector, + log_contig_dispatch_vector, log_strided_dispatch_vector); + }; + m.def("_log", log_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto log_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, log_output_typeid_vector); + }; + m.def("_log_result_type", log_result_type_pyapi); + } + + // U21: ==== LOG1P (x) + { + impl::populate_log1p_dispatch_vectors(); + using impl::log1p_contig_dispatch_vector; + using impl::log1p_output_typeid_vector; + using impl::log1p_strided_dispatch_vector; + + auto log1p_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, log1p_output_typeid_vector, + log1p_contig_dispatch_vector, log1p_strided_dispatch_vector); + }; + m.def("_log1p", log1p_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto log1p_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, + log1p_output_typeid_vector); + }; + m.def("_log1p_result_type", log1p_result_type_pyapi); + } + + // U22: ==== LOG2 (x) + { + impl::populate_log2_dispatch_vectors(); + + using impl::log2_contig_dispatch_vector; + using impl::log2_output_typeid_vector; + using impl::log2_strided_dispatch_vector; + auto log2_pyapi = [&](const dpctl::tensor::usm_ndarray &src, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, log2_output_typeid_vector, + log2_contig_dispatch_vector, log2_strided_dispatch_vector); + }; + auto log2_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, log2_output_typeid_vector); + }; + m.def("_log2", log2_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + m.def("_log2_result_type", log2_result_type_pyapi, ""); + } + + // U23: ==== LOG10 (x) + { + impl::populate_log10_dispatch_vectors(); + + using impl::log10_contig_dispatch_vector; + using impl::log10_output_typeid_vector; + using impl::log10_strided_dispatch_vector; + auto log10_pyapi = [&](const dpctl::tensor::usm_ndarray &src, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, log10_output_typeid_vector, + log10_contig_dispatch_vector, log10_strided_dispatch_vector); + }; + auto log10_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, + log10_output_typeid_vector); + }; + m.def("_log10", log10_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + m.def("_log10_result_type", log10_result_type_pyapi, ""); + } + + // B15: ==== LOGADDEXP (x1, x2) + { + impl::populate_logaddexp_dispatch_tables(); + using impl::logaddexp_contig_dispatch_table; + using impl::logaddexp_output_id_table; + using impl::logaddexp_strided_dispatch_table; + + auto logaddexp_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, + const dpctl::tensor::usm_ndarray &src2, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends = + {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, logaddexp_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + logaddexp_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + logaddexp_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto logaddexp_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + logaddexp_output_id_table); + }; + m.def("_logaddexp", logaddexp_pyapi, "", py::arg("src1"), + py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_logaddexp_result_type", logaddexp_result_type_pyapi, ""); + } + + // B16: ==== LOGICAL_AND (x1, x2) + { + impl::populate_logical_and_dispatch_tables(); + using impl::logical_and_contig_dispatch_table; + using impl::logical_and_output_id_table; + using impl::logical_and_strided_dispatch_table; + + auto logical_and_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, + const dpctl::tensor::usm_ndarray &src2, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends = + {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, logical_and_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + logical_and_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + logical_and_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto logical_and_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + logical_and_output_id_table); + }; + m.def("_logical_and", logical_and_pyapi, "", py::arg("src1"), + py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_logical_and_result_type", logical_and_result_type_pyapi, ""); + } + + // U24: ==== LOGICAL_NOT (x) + { + impl::populate_logical_not_dispatch_vectors(); + using impl::logical_not_contig_dispatch_vector; + using impl::logical_not_output_typeid_vector; + using impl::logical_not_strided_dispatch_vector; + + auto logical_not_pyapi = [&](const arrayT &src, arrayT dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc(src, dst, exec_q, depends, + logical_not_output_typeid_vector, + logical_not_contig_dispatch_vector, + logical_not_strided_dispatch_vector); + }; + m.def("_logical_not", logical_not_pyapi, "", py::arg("src"), + py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + + auto logical_not_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, + logical_not_output_typeid_vector); + }; + m.def("_logical_not_result_type", logical_not_result_type_pyapi); + } + + // B17: ==== LOGICAL_OR (x1, x2) + { + impl::populate_logical_or_dispatch_tables(); + using impl::logical_or_contig_dispatch_table; + using impl::logical_or_output_id_table; + using impl::logical_or_strided_dispatch_table; + + auto logical_or_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, + const dpctl::tensor::usm_ndarray &src2, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends = + {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, logical_or_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + logical_or_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + logical_or_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto logical_or_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + logical_or_output_id_table); + }; + m.def("_logical_or", logical_or_pyapi, "", py::arg("src1"), + py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_logical_or_result_type", logical_or_result_type_pyapi, ""); + } + + // B18: ==== LOGICAL_XOR (x1, x2) + { + impl::populate_logical_xor_dispatch_tables(); + using impl::logical_xor_contig_dispatch_table; + using impl::logical_xor_output_id_table; + using impl::logical_xor_strided_dispatch_table; + + auto logical_xor_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, + const dpctl::tensor::usm_ndarray &src2, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends = + {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, logical_xor_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + logical_xor_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + logical_xor_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto logical_xor_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + logical_xor_output_id_table); + }; + m.def("_logical_xor", logical_xor_pyapi, "", py::arg("src1"), + py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_logical_xor_result_type", logical_xor_result_type_pyapi, ""); + } + + // B??: ==== MAXIMUM (x1, x2) + { + impl::populate_maximum_dispatch_tables(); + using impl::maximum_contig_dispatch_table; + using impl::maximum_output_id_table; + using impl::maximum_strided_dispatch_table; + + auto maximum_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, + const dpctl::tensor::usm_ndarray &src2, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends = {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, maximum_output_id_table, + // function pointers to handle operation on contiguous + // arrays (pointers may be nullptr) + maximum_contig_dispatch_table, + // function pointers to handle operation on strided arrays + // (most general case) + maximum_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix + // and c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix + // and c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto maximum_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + maximum_output_id_table); + }; + m.def("_maximum", maximum_pyapi, "", py::arg("src1"), py::arg("src2"), + py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_maximum_result_type", maximum_result_type_pyapi, ""); + } + + // B??: ==== MINIMUM (x1, x2) + { + impl::populate_minimum_dispatch_tables(); + using impl::minimum_contig_dispatch_table; + using impl::minimum_output_id_table; + using impl::minimum_strided_dispatch_table; + + auto minimum_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, + const dpctl::tensor::usm_ndarray &src2, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends = {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, minimum_output_id_table, + // function pointers to handle operation on contiguous + // arrays (pointers may be nullptr) + minimum_contig_dispatch_table, + // function pointers to handle operation on strided arrays + // (most general case) + minimum_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix + // and c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto minimum_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + minimum_output_id_table); + }; + m.def("_minimum", minimum_pyapi, "", py::arg("src1"), py::arg("src2"), + py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_minimum_result_type", minimum_result_type_pyapi, ""); + } + + // B19: ==== MULTIPLY (x1, x2) + { + impl::populate_multiply_dispatch_tables(); + using impl::multiply_contig_dispatch_table; + using impl::multiply_contig_matrix_contig_row_broadcast_dispatch_table; + using impl::multiply_contig_row_contig_matrix_broadcast_dispatch_table; + using impl::multiply_output_id_table; + using impl::multiply_strided_dispatch_table; + + auto multiply_pyapi = + [&](const dpctl::tensor::usm_ndarray &src1, + const dpctl::tensor::usm_ndarray &src2, + const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, + const std::vector &depends = {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, multiply_output_id_table, + // function pointers to handle operation on contiguous + // arrays (pointers may be nullptr) + multiply_contig_dispatch_table, + // function pointers to handle operation on strided arrays + // (most general case) + multiply_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix + // and c-contig row with broadcasting (may be nullptr) + multiply_contig_matrix_contig_row_broadcast_dispatch_table, + // function pointers to handle operation of c-contig matrix + // and c-contig row with broadcasting (may be nullptr) + multiply_contig_row_contig_matrix_broadcast_dispatch_table); + }; + auto multiply_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + multiply_output_id_table); + }; + m.def("_multiply", multiply_pyapi, "", py::arg("src1"), py::arg("src2"), + py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_multiply_result_type", multiply_result_type_pyapi, ""); + + using impl::multiply_inplace_contig_dispatch_table; + using impl::multiply_inplace_row_matrix_dispatch_table; + using impl::multiply_inplace_strided_dispatch_table; + + auto multiply_inplace_pyapi = + [&](const dpctl::tensor::usm_ndarray &src, + const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, + const std::vector &depends = {}) { + return py_binary_inplace_ufunc( + src, dst, exec_q, depends, multiply_output_id_table, + // function pointers to handle inplace operation on + // contiguous arrays (pointers may be nullptr) + multiply_inplace_contig_dispatch_table, + // function pointers to handle inplace operation on strided + // arrays (most general case) + multiply_inplace_strided_dispatch_table, + // function pointers to handle inplace operation on + // c-contig matrix with c-contig row with broadcasting + // (may be nullptr) + multiply_inplace_row_matrix_dispatch_table); + }; + m.def("_multiply_inplace", multiply_inplace_pyapi, "", py::arg("lhs"), + py::arg("rhs"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + } + + // U25: ==== NEGATIVE (x) + { + impl::populate_negative_dispatch_vectors(); + using impl::negative_contig_dispatch_vector; + using impl::negative_output_typeid_vector; + using impl::negative_strided_dispatch_vector; + + auto negative_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc(src, dst, exec_q, depends, + negative_output_typeid_vector, + negative_contig_dispatch_vector, + negative_strided_dispatch_vector); + }; + m.def("_negative", negative_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto negative_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, + negative_output_typeid_vector); + }; + m.def("_negative_result_type", negative_result_type_pyapi); + } + + // B20: ==== NOT_EQUAL (x1, x2) + { + impl::populate_not_equal_dispatch_tables(); + using impl::not_equal_contig_dispatch_table; + using impl::not_equal_output_id_table; + using impl::not_equal_strided_dispatch_table; + + auto not_equal_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, + const dpctl::tensor::usm_ndarray &src2, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends = + {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, not_equal_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + not_equal_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + not_equal_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto not_equal_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + not_equal_output_id_table); + }; + m.def("_not_equal", not_equal_pyapi, "", py::arg("src1"), + py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_not_equal_result_type", not_equal_result_type_pyapi, ""); + } + + // U26: ==== POSITIVE (x) + { + impl::populate_positive_dispatch_vectors(); + using impl::positive_contig_dispatch_vector; + using impl::positive_output_typeid_vector; + using impl::positive_strided_dispatch_vector; + + auto positive_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc(src, dst, exec_q, depends, + positive_output_typeid_vector, + positive_contig_dispatch_vector, + positive_strided_dispatch_vector); + }; + m.def("_positive", positive_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto positive_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, + positive_output_typeid_vector); + }; + m.def("_positive_result_type", positive_result_type_pyapi); + } + + // B21: ==== POW (x1, x2) + { + impl::populate_pow_dispatch_tables(); + using impl::pow_contig_dispatch_table; + using impl::pow_output_id_table; + using impl::pow_strided_dispatch_table; + + auto pow_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, + const dpctl::tensor::usm_ndarray &src2, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends = {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, pow_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + pow_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + pow_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto pow_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + pow_output_id_table); + }; + m.def("_pow", pow_pyapi, "", py::arg("src1"), py::arg("src2"), + py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_pow_result_type", pow_result_type_pyapi, ""); + } + + // U??: ==== PROJ (x) + { + impl::populate_proj_dispatch_vectors(); + using impl::proj_contig_dispatch_vector; + using impl::proj_output_typeid_vector; + using impl::proj_strided_dispatch_vector; + + auto proj_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, proj_output_typeid_vector, + proj_contig_dispatch_vector, proj_strided_dispatch_vector); + }; + m.def("_proj", proj_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto proj_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, proj_output_typeid_vector); + }; + m.def("_proj_result_type", proj_result_type_pyapi); + } + + // U27: ==== REAL (x) + { + impl::populate_real_dispatch_vectors(); + using impl::real_contig_dispatch_vector; + using impl::real_output_typeid_vector; + using impl::real_strided_dispatch_vector; + + auto real_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, real_output_typeid_vector, + real_contig_dispatch_vector, real_strided_dispatch_vector); + }; + m.def("_real", real_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto real_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, real_output_typeid_vector); + }; + m.def("_real_result_type", real_result_type_pyapi); + } + + // B22: ==== REMAINDER (x1, x2) + { + impl::populate_remainder_dispatch_tables(); + using impl::remainder_contig_dispatch_table; + using impl::remainder_output_id_table; + using impl::remainder_strided_dispatch_table; + + auto remainder_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, + const dpctl::tensor::usm_ndarray &src2, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends = + {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, remainder_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + remainder_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + remainder_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto remainder_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + remainder_output_id_table); + }; + m.def("_remainder", remainder_pyapi, "", py::arg("src1"), + py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_remainder_result_type", remainder_result_type_pyapi, ""); + } + + // U28: ==== ROUND (x) + { + impl::populate_round_dispatch_vectors(); + using impl::round_contig_dispatch_vector; + using impl::round_output_typeid_vector; + using impl::round_strided_dispatch_vector; + + auto round_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, round_output_typeid_vector, + round_contig_dispatch_vector, round_strided_dispatch_vector); + }; + m.def("_round", round_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto round_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, + round_output_typeid_vector); + }; + m.def("_round_result_type", round_result_type_pyapi); + } + + // U29: ==== SIGN (x) + { + impl::populate_sign_dispatch_vectors(); + using impl::sign_contig_dispatch_vector; + using impl::sign_output_typeid_vector; + using impl::sign_strided_dispatch_vector; + + auto sign_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, sign_output_typeid_vector, + sign_contig_dispatch_vector, sign_strided_dispatch_vector); + }; + m.def("_sign", sign_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto sign_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, sign_output_typeid_vector); + }; + m.def("_sign_result_type", sign_result_type_pyapi); + } + + // ==== SIGNBIT (x) + { + impl::populate_signbit_dispatch_vectors(); + using impl::signbit_contig_dispatch_vector; + using impl::signbit_output_typeid_vector; + using impl::signbit_strided_dispatch_vector; + + auto signbit_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc(src, dst, exec_q, depends, + signbit_output_typeid_vector, + signbit_contig_dispatch_vector, + signbit_strided_dispatch_vector); + }; + m.def("_signbit", signbit_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto signbit_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, + signbit_output_typeid_vector); + }; + m.def("_signbit_result_type", signbit_result_type_pyapi); + } + + // U30: ==== SIN (x) + { + impl::populate_sin_dispatch_vectors(); + using impl::sin_contig_dispatch_vector; + using impl::sin_output_typeid_vector; + using impl::sin_strided_dispatch_vector; + + auto sin_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, sin_output_typeid_vector, + sin_contig_dispatch_vector, sin_strided_dispatch_vector); + }; + m.def("_sin", sin_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto sin_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, sin_output_typeid_vector); + }; + m.def("_sin_result_type", sin_result_type_pyapi); + } + // U31: ==== SINH (x) + { + impl::populate_sinh_dispatch_vectors(); + using impl::sinh_contig_dispatch_vector; + using impl::sinh_output_typeid_vector; + using impl::sinh_strided_dispatch_vector; + + auto sinh_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, sinh_output_typeid_vector, + sinh_contig_dispatch_vector, sinh_strided_dispatch_vector); + }; + m.def("_sinh", sinh_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto sinh_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, sinh_output_typeid_vector); + }; + m.def("_sinh_result_type", sinh_result_type_pyapi); + } + + // U32: ==== SQUARE (x) + { + impl::populate_square_dispatch_vectors(); + using impl::square_contig_dispatch_vector; + using impl::square_output_typeid_vector; + using impl::square_strided_dispatch_vector; + + auto square_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, square_output_typeid_vector, + square_contig_dispatch_vector, square_strided_dispatch_vector); + }; + m.def("_square", square_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto square_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, + square_output_typeid_vector); + }; + m.def("_square_result_type", square_result_type_pyapi); + } + + // U33: ==== SQRT (x) + { + impl::populate_sqrt_dispatch_vectors(); + using impl::sqrt_contig_dispatch_vector; + using impl::sqrt_output_typeid_vector; + using impl::sqrt_strided_dispatch_vector; + + auto sqrt_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, sqrt_output_typeid_vector, + sqrt_contig_dispatch_vector, sqrt_strided_dispatch_vector); + }; + m.def("_sqrt", sqrt_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto sqrt_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, sqrt_output_typeid_vector); + }; + m.def("_sqrt_result_type", sqrt_result_type_pyapi); + } + + // B23: ==== SUBTRACT (x1, x2) + { + impl::populate_subtract_dispatch_tables(); + using impl::subtract_contig_dispatch_table; + using impl::subtract_contig_matrix_contig_row_broadcast_dispatch_table; + using impl::subtract_contig_row_contig_matrix_broadcast_dispatch_table; + using impl::subtract_output_id_table; + using impl::subtract_strided_dispatch_table; + + auto subtract_pyapi = + [&](const dpctl::tensor::usm_ndarray &src1, + const dpctl::tensor::usm_ndarray &src2, + const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, + const std::vector &depends = {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, subtract_output_id_table, + // function pointers to handle operation on contiguous + // arrays (pointers may be nullptr) + subtract_contig_dispatch_table, + // function pointers to handle operation on strided arrays + // (most general case) + subtract_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix + // and c-contig row with broadcasting (may be nullptr) + subtract_contig_matrix_contig_row_broadcast_dispatch_table, + // function pointers to handle operation of c-contig matrix + // and c-contig row with broadcasting (may be nullptr) + subtract_contig_row_contig_matrix_broadcast_dispatch_table); + }; + auto subtract_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + subtract_output_id_table); + }; + m.def("_subtract", subtract_pyapi, "", py::arg("src1"), py::arg("src2"), + py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_subtract_result_type", subtract_result_type_pyapi, ""); + + using impl::subtract_inplace_contig_dispatch_table; + using impl::subtract_inplace_row_matrix_dispatch_table; + using impl::subtract_inplace_strided_dispatch_table; + + auto subtract_inplace_pyapi = + [&](const dpctl::tensor::usm_ndarray &src, + const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, + const std::vector &depends = {}) { + return py_binary_inplace_ufunc( + src, dst, exec_q, depends, subtract_output_id_table, + // function pointers to handle inplace operation on + // contiguous arrays (pointers may be nullptr) + subtract_inplace_contig_dispatch_table, + // function pointers to handle inplace operation on strided + // arrays (most general case) + subtract_inplace_strided_dispatch_table, + // function pointers to handle inplace operation on + // c-contig matrix with c-contig row with broadcasting + // (may be nullptr) + subtract_inplace_row_matrix_dispatch_table); + }; + m.def("_subtract_inplace", subtract_inplace_pyapi, "", py::arg("lhs"), + py::arg("rhs"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + } + + // U34: ==== TAN (x) + { + impl::populate_tan_dispatch_vectors(); + using impl::tan_contig_dispatch_vector; + using impl::tan_output_typeid_vector; + using impl::tan_strided_dispatch_vector; + + auto tan_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, tan_output_typeid_vector, + tan_contig_dispatch_vector, tan_strided_dispatch_vector); + }; + m.def("_tan", tan_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto tan_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, tan_output_typeid_vector); + }; + m.def("_tan_result_type", tan_result_type_pyapi); + } + + // U35: ==== TANH (x) + { + impl::populate_tanh_dispatch_vectors(); + using impl::tanh_contig_dispatch_vector; + using impl::tanh_output_typeid_vector; + using impl::tanh_strided_dispatch_vector; + + auto tanh_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, tanh_output_typeid_vector, + tanh_contig_dispatch_vector, tanh_strided_dispatch_vector); + }; + m.def("_tanh", tanh_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto tanh_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, tanh_output_typeid_vector); + }; + m.def("_tanh_result_type", tanh_result_type_pyapi); + } + + // U36: ==== TRUNC (x) + { + impl::populate_trunc_dispatch_vectors(); + using impl::trunc_contig_dispatch_vector; + using impl::trunc_output_typeid_vector; + using impl::trunc_strided_dispatch_vector; + + auto trunc_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, trunc_output_typeid_vector, + trunc_contig_dispatch_vector, trunc_strided_dispatch_vector); + }; + m.def("_trunc", trunc_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto trunc_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, + trunc_output_typeid_vector); + }; + m.def("_trunc_result_type", trunc_result_type_pyapi); + } + + // B24: ==== HYPOT (x1, x2) + { + impl::populate_hypot_dispatch_tables(); + using impl::hypot_contig_dispatch_table; + using impl::hypot_output_id_table; + using impl::hypot_strided_dispatch_table; + + auto hypot_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, + const dpctl::tensor::usm_ndarray &src2, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends = {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, hypot_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + hypot_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + hypot_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto hypot_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + hypot_output_id_table); + }; + m.def("_hypot", hypot_pyapi, "", py::arg("src1"), py::arg("src2"), + py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_hypot_result_type", hypot_result_type_pyapi, ""); + } + + // U37: ==== CBRT (x) + { + impl::populate_cbrt_dispatch_vectors(); + using impl::cbrt_contig_dispatch_vector; + using impl::cbrt_output_typeid_vector; + using impl::cbrt_strided_dispatch_vector; + + auto cbrt_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, cbrt_output_typeid_vector, + cbrt_contig_dispatch_vector, cbrt_strided_dispatch_vector); + }; + m.def("_cbrt", cbrt_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto cbrt_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, cbrt_output_typeid_vector); + }; + m.def("_cbrt_result_type", cbrt_result_type_pyapi); + } + + // B25: ==== COPYSIGN (x1, x2) + { + impl::populate_copysign_dispatch_tables(); + using impl::copysign_contig_dispatch_table; + using impl::copysign_output_id_table; + using impl::copysign_strided_dispatch_table; + + auto copysign_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, + const dpctl::tensor::usm_ndarray &src2, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends = + {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, copysign_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + copysign_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + copysign_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto copysign_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + copysign_output_id_table); + }; + m.def("_copysign", copysign_pyapi, "", py::arg("src1"), py::arg("src2"), + py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_copysign_result_type", copysign_result_type_pyapi, ""); + } + + // U38: ==== EXP2 (x) + { + impl::populate_exp2_dispatch_vectors(); + using impl::exp2_contig_dispatch_vector; + using impl::exp2_output_typeid_vector; + using impl::exp2_strided_dispatch_vector; + + auto exp2_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, exp2_output_typeid_vector, + exp2_contig_dispatch_vector, exp2_strided_dispatch_vector); + }; + m.def("_exp2", exp2_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto exp2_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, exp2_output_typeid_vector); + }; + m.def("_exp2_result_type", exp2_result_type_pyapi); + } + + // U39: ==== RSQRT (x) + { + impl::populate_rsqrt_dispatch_vectors(); + using impl::rsqrt_contig_dispatch_vector; + using impl::rsqrt_output_typeid_vector; + using impl::rsqrt_strided_dispatch_vector; + + auto rsqrt_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, rsqrt_output_typeid_vector, + rsqrt_contig_dispatch_vector, rsqrt_strided_dispatch_vector); + }; + m.def("_rsqrt", rsqrt_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto rsqrt_result_type_pyapi = [&](const py::dtype &dtype) { + return py_unary_ufunc_result_type(dtype, + rsqrt_output_typeid_vector); + }; + m.def("_rsqrt_result_type", rsqrt_result_type_pyapi); + } +} + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp b/dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp index 6817a3541c..0410229a0a 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp @@ -25,10 +25,11 @@ #pragma once #include "dpctl4pybind11.hpp" -#include #include #include #include +#include +#include #include #include "elementwise_functions_type_utils.hpp" diff --git a/dpctl/tensor/libtensor/source/eye_ctor.cpp b/dpctl/tensor/libtensor/source/eye_ctor.cpp index 5d7657d047..c768a5e395 100644 --- a/dpctl/tensor/libtensor/source/eye_ctor.cpp +++ b/dpctl/tensor/libtensor/source/eye_ctor.cpp @@ -22,7 +22,7 @@ /// This file defines functions of dpctl.tensor._tensor_impl extensions //===--------------------------------------------------------------------===// -#include +#include #include #include diff --git a/dpctl/tensor/libtensor/source/eye_ctor.hpp b/dpctl/tensor/libtensor/source/eye_ctor.hpp index 4307e0f3b2..58249f08d7 100644 --- a/dpctl/tensor/libtensor/source/eye_ctor.hpp +++ b/dpctl/tensor/libtensor/source/eye_ctor.hpp @@ -23,7 +23,7 @@ //===--------------------------------------------------------------------===// #pragma once -#include +#include #include #include diff --git a/dpctl/tensor/libtensor/source/full_ctor.cpp b/dpctl/tensor/libtensor/source/full_ctor.cpp index 085bdcaf2a..c8004bfae8 100644 --- a/dpctl/tensor/libtensor/source/full_ctor.cpp +++ b/dpctl/tensor/libtensor/source/full_ctor.cpp @@ -23,10 +23,10 @@ //===--------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include diff --git a/dpctl/tensor/libtensor/source/full_ctor.hpp b/dpctl/tensor/libtensor/source/full_ctor.hpp index 3894babf1f..66456f9a7f 100644 --- a/dpctl/tensor/libtensor/source/full_ctor.hpp +++ b/dpctl/tensor/libtensor/source/full_ctor.hpp @@ -23,7 +23,7 @@ //===--------------------------------------------------------------------===// #pragma once -#include +#include #include #include diff --git a/dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp b/dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp index a17a229fc1..0fd3d2615d 100644 --- a/dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp +++ b/dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp @@ -23,13 +23,13 @@ /// dpctl.tensor.put //===----------------------------------------------------------------------===// -#include #include #include #include #include #include #include +#include #include #include "dpctl4pybind11.hpp" diff --git a/dpctl/tensor/libtensor/source/integer_advanced_indexing.hpp b/dpctl/tensor/libtensor/source/integer_advanced_indexing.hpp index f845f7d23b..011fe670a9 100644 --- a/dpctl/tensor/libtensor/source/integer_advanced_indexing.hpp +++ b/dpctl/tensor/libtensor/source/integer_advanced_indexing.hpp @@ -24,7 +24,7 @@ //===----------------------------------------------------------------------===// #pragma once -#include +#include #include #include diff --git a/dpctl/tensor/libtensor/source/linear_sequences.cpp b/dpctl/tensor/libtensor/source/linear_sequences.cpp index 34db93de12..72d292df5f 100644 --- a/dpctl/tensor/libtensor/source/linear_sequences.cpp +++ b/dpctl/tensor/libtensor/source/linear_sequences.cpp @@ -23,10 +23,10 @@ //===--------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include +#include #include #include diff --git a/dpctl/tensor/libtensor/source/linear_sequences.hpp b/dpctl/tensor/libtensor/source/linear_sequences.hpp index 61e613b45f..fd13677680 100644 --- a/dpctl/tensor/libtensor/source/linear_sequences.hpp +++ b/dpctl/tensor/libtensor/source/linear_sequences.hpp @@ -23,7 +23,7 @@ //===--------------------------------------------------------------------===// #pragma once -#include +#include #include #include diff --git a/dpctl/tensor/libtensor/source/reduction_over_axis.cpp b/dpctl/tensor/libtensor/source/reduction_over_axis.cpp new file mode 100644 index 0000000000..00e4a0a076 --- /dev/null +++ b/dpctl/tensor/libtensor/source/reduction_over_axis.cpp @@ -0,0 +1,514 @@ +//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===--------------------------------------------------------------------===// + +#include +#include +#include +#include + +#include +#include +#include + +#include "dpctl4pybind11.hpp" +#include "kernels/reductions.hpp" +#include "reduction_over_axis.hpp" +#include "simplify_iteration_space.hpp" +#include "utils/type_dispatch.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +namespace td_ns = dpctl::tensor::type_dispatch; +// Max +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + max_over_axis_strided_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_strided_impl_fn_ptr + max_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + max_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + max_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_max_over_axis_dispatch_tables(void) +{ + using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; + using td_ns::DispatchTableBuilder; + + using dpctl::tensor::kernels::MaxOverAxisAtomicStridedFactory; + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(max_over_axis_strided_atomic_dispatch_table); + + using dpctl::tensor::kernels::MaxOverAxisTempsStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(max_over_axis_strided_temps_dispatch_table); + + using dpctl::tensor::kernels::MaxOverAxis1AtomicContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(max_over_axis1_contig_atomic_dispatch_table); + + using dpctl::tensor::kernels::MaxOverAxis0AtomicContigFactory; + DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(max_over_axis0_contig_atomic_dispatch_table); +} + +} // namespace impl + +// Min +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + min_over_axis_strided_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_strided_impl_fn_ptr + min_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + min_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + min_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_min_over_axis_dispatch_tables(void) +{ + using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; + using td_ns::DispatchTableBuilder; + + using dpctl::tensor::kernels::MinOverAxisAtomicStridedFactory; + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(min_over_axis_strided_atomic_dispatch_table); + + using dpctl::tensor::kernels::MinOverAxisTempsStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(min_over_axis_strided_temps_dispatch_table); + + using dpctl::tensor::kernels::MinOverAxis1AtomicContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(min_over_axis1_contig_atomic_dispatch_table); + + using dpctl::tensor::kernels::MinOverAxis0AtomicContigFactory; + DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(min_over_axis0_contig_atomic_dispatch_table); +} + +} // namespace impl + +// Sum +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + sum_over_axis_strided_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_strided_impl_fn_ptr + sum_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + sum_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + sum_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_sum_over_axis_dispatch_tables(void) +{ + using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; + using namespace td_ns; + + using dpctl::tensor::kernels::SumOverAxisAtomicStridedFactory; + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(sum_over_axis_strided_atomic_dispatch_table); + + using dpctl::tensor::kernels::SumOverAxisTempsStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(sum_over_axis_strided_temps_dispatch_table); + + using dpctl::tensor::kernels::SumOverAxis1AtomicContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(sum_over_axis1_contig_atomic_dispatch_table); + + using dpctl::tensor::kernels::SumOverAxis0AtomicContigFactory; + DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(sum_over_axis0_contig_atomic_dispatch_table); +} + +} // namespace impl + +// Product +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + prod_over_axis_strided_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_strided_impl_fn_ptr + prod_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + prod_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + prod_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_prod_over_axis_dispatch_tables(void) +{ + using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; + using namespace td_ns; + + using dpctl::tensor::kernels::ProductOverAxisAtomicStridedFactory; + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(prod_over_axis_strided_atomic_dispatch_table); + + using dpctl::tensor::kernels::ProductOverAxisTempsStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(prod_over_axis_strided_temps_dispatch_table); + + using dpctl::tensor::kernels::ProductOverAxis1AtomicContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(prod_over_axis1_contig_atomic_dispatch_table); + + using dpctl::tensor::kernels::ProductOverAxis0AtomicContigFactory; + DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(prod_over_axis0_contig_atomic_dispatch_table); +} + +} // namespace impl + +// Argmax +namespace impl +{ + +using dpctl::tensor::kernels::search_reduction_strided_impl_fn_ptr; +static search_reduction_strided_impl_fn_ptr + argmax_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_argmax_over_axis_dispatch_tables(void) +{ + using dpctl::tensor::kernels::search_reduction_strided_impl_fn_ptr; + using td_ns::DispatchTableBuilder; + + using dpctl::tensor::kernels::ArgmaxOverAxisTempsStridedFactory; + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(argmax_over_axis_strided_temps_dispatch_table); +} + +} // namespace impl + +// Argmin +namespace impl +{ + +using dpctl::tensor::kernels::search_reduction_strided_impl_fn_ptr; +static search_reduction_strided_impl_fn_ptr + argmin_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_argmin_over_axis_dispatch_tables(void) +{ + using dpctl::tensor::kernels::search_reduction_strided_impl_fn_ptr; + using td_ns::DispatchTableBuilder; + + using dpctl::tensor::kernels::ArgminOverAxisTempsStridedFactory; + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(argmin_over_axis_strided_temps_dispatch_table); +} + +} // namespace impl + +namespace py = pybind11; + +void init_reduction_functions(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + + namespace impl = dpctl::tensor::py_internal::impl; + + using dpctl::tensor::py_internal::py_reduction_dtype_supported; + using dpctl::tensor::py_internal::py_reduction_over_axis; + + using dpctl::tensor::py_internal::check_atomic_support; + using dpctl::tensor::py_internal::fixed_decision; + + // MAX + { + using dpctl::tensor::py_internal::impl:: + populate_max_over_axis_dispatch_tables; + populate_max_over_axis_dispatch_tables(); + using impl::max_over_axis0_contig_atomic_dispatch_table; + using impl::max_over_axis1_contig_atomic_dispatch_table; + using impl::max_over_axis_strided_atomic_dispatch_table; + using impl::max_over_axis_strided_temps_dispatch_table; + + const auto &check_atomic_support_size4 = + check_atomic_support; + const auto &check_atomic_support_size8 = + check_atomic_support; + + auto max_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_reduction_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + max_over_axis_strided_atomic_dispatch_table, + max_over_axis_strided_temps_dispatch_table, + max_over_axis0_contig_atomic_dispatch_table, + max_over_axis1_contig_atomic_dispatch_table, + check_atomic_support_size4, check_atomic_support_size8); + }; + m.def("_max_over_axis", max_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + } + + // MIN + { + using dpctl::tensor::py_internal::impl:: + populate_min_over_axis_dispatch_tables; + populate_min_over_axis_dispatch_tables(); + using impl::min_over_axis0_contig_atomic_dispatch_table; + using impl::min_over_axis1_contig_atomic_dispatch_table; + using impl::min_over_axis_strided_atomic_dispatch_table; + using impl::min_over_axis_strided_temps_dispatch_table; + + const auto &check_atomic_support_size4 = + check_atomic_support; + const auto &check_atomic_support_size8 = + check_atomic_support; + + auto min_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_reduction_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + min_over_axis_strided_atomic_dispatch_table, + min_over_axis_strided_temps_dispatch_table, + min_over_axis0_contig_atomic_dispatch_table, + min_over_axis1_contig_atomic_dispatch_table, + check_atomic_support_size4, check_atomic_support_size8); + }; + m.def("_min_over_axis", min_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + } + + // SUM + { + using dpctl::tensor::py_internal::impl:: + populate_sum_over_axis_dispatch_tables; + populate_sum_over_axis_dispatch_tables(); + using impl::sum_over_axis0_contig_atomic_dispatch_table; + using impl::sum_over_axis1_contig_atomic_dispatch_table; + using impl::sum_over_axis_strided_atomic_dispatch_table; + using impl::sum_over_axis_strided_temps_dispatch_table; + + const auto &check_atomic_support_size4 = + check_atomic_support; + const auto &check_atomic_support_size8 = + check_atomic_support; + + auto sum_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_reduction_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + sum_over_axis_strided_atomic_dispatch_table, + sum_over_axis_strided_temps_dispatch_table, + sum_over_axis0_contig_atomic_dispatch_table, + sum_over_axis1_contig_atomic_dispatch_table, + check_atomic_support_size4, check_atomic_support_size8); + }; + m.def("_sum_over_axis", sum_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto sum_dtype_supported = + [&](const py::dtype &input_dtype, const py::dtype &output_dtype, + const std::string &dst_usm_type, sycl::queue &q) { + return py_reduction_dtype_supported( + input_dtype, output_dtype, dst_usm_type, q, + sum_over_axis_strided_atomic_dispatch_table, + sum_over_axis_strided_temps_dispatch_table, + check_atomic_support_size4, check_atomic_support_size8); + }; + m.def("_sum_over_axis_dtype_supported", sum_dtype_supported, "", + py::arg("arg_dtype"), py::arg("out_dtype"), + py::arg("dst_usm_type"), py::arg("sycl_queue")); + } + + // PROD + { + using dpctl::tensor::py_internal::impl:: + populate_prod_over_axis_dispatch_tables; + populate_prod_over_axis_dispatch_tables(); + using impl::prod_over_axis0_contig_atomic_dispatch_table; + using impl::prod_over_axis1_contig_atomic_dispatch_table; + using impl::prod_over_axis_strided_atomic_dispatch_table; + using impl::prod_over_axis_strided_temps_dispatch_table; + + const auto &check_atomic_support_size4 = + check_atomic_support; + const auto &check_atomic_support_size8 = + check_atomic_support; + + auto prod_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_reduction_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + prod_over_axis_strided_atomic_dispatch_table, + prod_over_axis_strided_temps_dispatch_table, + prod_over_axis0_contig_atomic_dispatch_table, + prod_over_axis1_contig_atomic_dispatch_table, + check_atomic_support_size4, check_atomic_support_size8); + }; + m.def("_prod_over_axis", prod_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto prod_dtype_supported = + [&](const py::dtype &input_dtype, const py::dtype &output_dtype, + const std::string &dst_usm_type, sycl::queue &q) { + return py_reduction_dtype_supported( + input_dtype, output_dtype, dst_usm_type, q, + prod_over_axis_strided_atomic_dispatch_table, + prod_over_axis_strided_temps_dispatch_table, + check_atomic_support_size4, check_atomic_support_size8); + }; + m.def("_prod_over_axis_dtype_supported", prod_dtype_supported, "", + py::arg("arg_dtype"), py::arg("out_dtype"), + py::arg("dst_usm_type"), py::arg("sycl_queue")); + } + + // ARGMAX + { + using dpctl::tensor::py_internal::impl:: + populate_argmax_over_axis_dispatch_tables; + populate_argmax_over_axis_dispatch_tables(); + using impl::argmax_over_axis_strided_temps_dispatch_table; + + auto argmax_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + using dpctl::tensor::py_internal::py_search_over_axis; + return py_search_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + argmax_over_axis_strided_temps_dispatch_table); + }; + m.def("_argmax_over_axis", argmax_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + } + + // ARGMIN + { + using dpctl::tensor::py_internal::impl:: + populate_argmin_over_axis_dispatch_tables; + populate_argmin_over_axis_dispatch_tables(); + using impl::argmin_over_axis_strided_temps_dispatch_table; + + auto argmin_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + using dpctl::tensor::py_internal::py_search_over_axis; + return py_search_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + argmin_over_axis_strided_temps_dispatch_table); + }; + m.def("_argmin_over_axis", argmin_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + } +} + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/reduction_over_axis.hpp b/dpctl/tensor/libtensor/source/reduction_over_axis.hpp new file mode 100644 index 0000000000..e9ccd1d52a --- /dev/null +++ b/dpctl/tensor/libtensor/source/reduction_over_axis.hpp @@ -0,0 +1,691 @@ +//===----------- Implementation of _tensor_impl module ---------*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions, +/// specifically functions for reductions. +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "dpctl4pybind11.hpp" +#include +#include +#include + +#include "kernels/reductions.hpp" +#include "simplify_iteration_space.hpp" +#include "utils/memory_overlap.hpp" +#include "utils/offset_utils.hpp" +#include "utils/type_dispatch.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +template +bool check_atomic_support(const sycl::queue &exec_q, + sycl::usm::alloc usm_alloc_type) +{ + bool supports_atomics = false; + + const sycl::device &dev = exec_q.get_device(); + + if constexpr (require_atomic64) { + if (!dev.has(sycl::aspect::atomic64)) + return false; + } + + switch (usm_alloc_type) { + case sycl::usm::alloc::shared: + supports_atomics = dev.has(sycl::aspect::usm_atomic_shared_allocations); + break; + case sycl::usm::alloc::host: + supports_atomics = dev.has(sycl::aspect::usm_atomic_host_allocations); + break; + case sycl::usm::alloc::device: + supports_atomics = true; + break; + default: + supports_atomics = false; + } + + return supports_atomics; +} + +template +bool fixed_decision(const sycl::queue &, sycl::usm::alloc) +{ + return return_value; +} + +/* ====================== dtype supported ======================== */ + +template +bool py_reduction_dtype_supported( + const py::dtype &input_dtype, + const py::dtype &output_dtype, + const std::string &dst_usm_type, + sycl::queue &q, + const fnT &atomic_dispatch_table, + const fnT &temps_dispatch_table, + const CheckAtomicSupportFnT &check_atomic_support_size4, + const CheckAtomicSupportFnT &check_atomic_support_size8) +{ + int arg_tn = + input_dtype.num(); // NumPy type numbers are the same as in dpctl + int out_tn = + output_dtype.num(); // NumPy type numbers are the same as in dpctl + int arg_typeid = -1; + int out_typeid = -1; + + auto array_types = td_ns::usm_ndarray_types(); + + try { + arg_typeid = array_types.typenum_to_lookup_id(arg_tn); + out_typeid = array_types.typenum_to_lookup_id(out_tn); + } catch (const std::exception &e) { + throw py::value_error(e.what()); + } + + if (arg_typeid < 0 || arg_typeid >= td_ns::num_types || out_typeid < 0 || + out_typeid >= td_ns::num_types) + { + throw std::runtime_error("Reduction type support check: lookup failed"); + } + + // remove_all_extents gets underlying type of table + using fn_ptrT = typename std::remove_all_extents::type; + fn_ptrT fn = nullptr; + + sycl::usm::alloc kind = sycl::usm::alloc::unknown; + + if (dst_usm_type == "device") { + kind = sycl::usm::alloc::device; + } + else if (dst_usm_type == "shared") { + kind = sycl::usm::alloc::shared; + } + else if (dst_usm_type == "host") { + kind = sycl::usm::alloc::host; + } + else { + throw py::value_error("Unrecognized `dst_usm_type` argument."); + } + + bool supports_atomics = false; + + switch (output_dtype.itemsize()) { + case sizeof(float): + { + supports_atomics = check_atomic_support_size4(q, kind); + } break; + case sizeof(double): + { + supports_atomics = check_atomic_support_size8(q, kind); + } break; + } + + if (supports_atomics) { + fn = atomic_dispatch_table[arg_typeid][out_typeid]; + } + + if (fn == nullptr) { + // use slower reduction implementation using temporaries + fn = temps_dispatch_table[arg_typeid][out_typeid]; + } + + return (fn != nullptr); +} + +/* ==================== Generic reductions ====================== */ + +template +std::pair py_reduction_over_axis( + const dpctl::tensor::usm_ndarray &src, + int trailing_dims_to_reduce, // comp over this many trailing indexes + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends, + const strided_fnT &atomic_dispatch_table, + const strided_fnT &temps_dispatch_table, + const contig_fnT &axis0_dispatch_table, + const contig_fnT &axis1_dispatch_table, + const SupportAtomicFnT &check_atomic_support_size4, + const SupportAtomicFnT &check_atomic_support_size8) +{ + int src_nd = src.get_ndim(); + int iteration_nd = src_nd - trailing_dims_to_reduce; + if (trailing_dims_to_reduce <= 0 || iteration_nd < 0) { + throw py::value_error("Trailing_dim_to_reduce must be positive, but no " + "greater than rank of the array being reduced"); + } + + int dst_nd = dst.get_ndim(); + if (dst_nd != iteration_nd) { + throw py::value_error("Destination array rank does not match input " + "array rank and number of reduced dimensions"); + } + + const py::ssize_t *src_shape_ptr = src.get_shape_raw(); + const py::ssize_t *dst_shape_ptr = dst.get_shape_raw(); + + bool same_shapes = true; + for (int i = 0; same_shapes && (i < dst_nd); ++i) { + same_shapes = same_shapes && (src_shape_ptr[i] == dst_shape_ptr[i]); + } + + if (!same_shapes) { + throw py::value_error("Destination shape does not match unreduced " + "dimensions of the input shape"); + } + + if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + size_t dst_nelems = dst.get_size(); + + size_t reduction_nelems(1); + for (int i = dst_nd; i < src_nd; ++i) { + reduction_nelems *= static_cast(src_shape_ptr[i]); + } + + // check that dst and src do not overlap + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(src, dst)) { + throw py::value_error("Arrays index overlapping segments of memory"); + } + + // destination must be ample enough to accommodate all elements + { + auto dst_offsets = dst.get_minmax_offsets(); + size_t range = + static_cast(dst_offsets.second - dst_offsets.first); + if (range + 1 < dst_nelems) { + throw py::value_error( + "Destination array can not accommodate all the " + "elements of source array."); + } + } + + int src_typenum = src.get_typenum(); + int dst_typenum = dst.get_typenum(); + + namespace td_ns = dpctl::tensor::type_dispatch; + const auto &array_types = td_ns::usm_ndarray_types(); + int src_typeid = array_types.typenum_to_lookup_id(src_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + int dst_itemsize = dst.get_elemsize(); + bool supports_atomics = false; + + switch (dst_itemsize) { + case sizeof(float): + { + void *data_ptr = dst.get_data(); + const auto &ctx = exec_q.get_context(); + auto usm_type = sycl::get_pointer_type(data_ptr, ctx); + supports_atomics = check_atomic_support_size4(exec_q, usm_type); + } break; + case sizeof(double): + { + void *data_ptr = dst.get_data(); + const auto &ctx = exec_q.get_context(); + auto usm_type = sycl::get_pointer_type(data_ptr, ctx); + + supports_atomics = check_atomic_support_size8(exec_q, usm_type); + } break; + } + + // handle special case when both reduction and iteration are 1D contiguous + // and can be done with atomics + if (supports_atomics) { + bool is_src_c_contig = src.is_c_contiguous(); + bool is_dst_c_contig = dst.is_c_contiguous(); + bool is_src_f_contig = src.is_f_contiguous(); + + if ((is_src_c_contig && is_dst_c_contig) || + (is_src_f_contig && dst_nelems == 1)) + { + auto fn = axis1_dispatch_table[src_typeid][dst_typeid]; + + if (fn != nullptr) { + size_t iter_nelems = dst_nelems; + + constexpr py::ssize_t zero_offset = 0; + + sycl::event reduction_over_axis_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), + zero_offset, // iteration_src_offset + zero_offset, // iteration_dst_offset + zero_offset, // reduction_src_offset + depends); + + sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {reduction_over_axis_contig_ev}); + + return std::make_pair(keep_args_event, + reduction_over_axis_contig_ev); + } + } + else if (is_src_f_contig && + ((is_dst_c_contig && dst_nd == 1) || dst.is_f_contiguous())) + { + auto fn = axis0_dispatch_table[src_typeid][dst_typeid]; + if (fn != nullptr) { + size_t iter_nelems = dst_nelems; + + constexpr py::ssize_t zero_offset = 0; + + sycl::event reduction_over_axis_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), + zero_offset, // iteration_src_offset + zero_offset, // iteration_dst_offset + zero_offset, // reduction_src_offset + depends); + + sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {reduction_over_axis_contig_ev}); + + return std::make_pair(keep_args_event, + reduction_over_axis_contig_ev); + } + } + } + + using dpctl::tensor::py_internal::simplify_iteration_space; + using dpctl::tensor::py_internal::simplify_iteration_space_1; + + auto const &src_shape_vecs = src.get_shape_vector(); + auto const &src_strides_vecs = src.get_strides_vector(); + auto const &dst_strides_vecs = dst.get_strides_vector(); + + int reduction_nd = trailing_dims_to_reduce; + const py::ssize_t *reduction_shape_ptr = src_shape_ptr + dst_nd; + using shT = std::vector; + shT reduction_src_strides(std::begin(src_strides_vecs) + dst_nd, + std::end(src_strides_vecs)); + + shT simplified_reduction_shape; + shT simplified_reduction_src_strides; + py::ssize_t reduction_src_offset(0); + + simplify_iteration_space_1( + reduction_nd, reduction_shape_ptr, reduction_src_strides, + // output + simplified_reduction_shape, simplified_reduction_src_strides, + reduction_src_offset); + + const py::ssize_t *iteration_shape_ptr = src_shape_ptr; + + shT iteration_src_strides(std::begin(src_strides_vecs), + std::begin(src_strides_vecs) + iteration_nd); + shT const &iteration_dst_strides = dst_strides_vecs; + + shT simplified_iteration_shape; + shT simplified_iteration_src_strides; + shT simplified_iteration_dst_strides; + py::ssize_t iteration_src_offset(0); + py::ssize_t iteration_dst_offset(0); + + if (iteration_nd == 0) { + if (dst_nelems != 1) { + throw std::runtime_error("iteration_nd == 0, but dst_nelems != 1"); + } + iteration_nd = 1; + simplified_iteration_shape.push_back(1); + simplified_iteration_src_strides.push_back(0); + simplified_iteration_dst_strides.push_back(0); + } + else { + simplify_iteration_space(iteration_nd, iteration_shape_ptr, + iteration_src_strides, iteration_dst_strides, + // output + simplified_iteration_shape, + simplified_iteration_src_strides, + simplified_iteration_dst_strides, + iteration_src_offset, iteration_dst_offset); + } + + if (supports_atomics && (reduction_nd == 1) && (iteration_nd == 1)) { + bool mat_reduce_over_axis1 = false; + bool mat_reduce_over_axis0 = false; + bool array_reduce_all_elems = false; + size_t iter_nelems = dst_nelems; + + if (simplified_reduction_src_strides[0] == 1) { + array_reduce_all_elems = (simplified_iteration_shape[0] == 1); + mat_reduce_over_axis1 = + (simplified_iteration_dst_strides[0] == 1) && + (static_cast(simplified_iteration_src_strides[0]) == + reduction_nelems); + } + else if (static_cast(simplified_reduction_src_strides[0]) == + iter_nelems) + { + mat_reduce_over_axis0 = + (simplified_iteration_dst_strides[0] == 1) && + (simplified_iteration_src_strides[0] == 1); + } + + if (mat_reduce_over_axis1 || array_reduce_all_elems) { + auto fn = axis1_dispatch_table[src_typeid][dst_typeid]; + if (fn != nullptr) { + sycl::event reduction_over_axis1_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), iteration_src_offset, + iteration_dst_offset, reduction_src_offset, depends); + + sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {reduction_over_axis1_contig_ev}); + + return std::make_pair(keep_args_event, + reduction_over_axis1_contig_ev); + } + } + else if (mat_reduce_over_axis0) { + auto fn = axis0_dispatch_table[src_typeid][dst_typeid]; + if (fn != nullptr) { + sycl::event reduction_over_axis0_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), iteration_src_offset, + iteration_dst_offset, reduction_src_offset, depends); + + sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {reduction_over_axis0_contig_ev}); + + return std::make_pair(keep_args_event, + reduction_over_axis0_contig_ev); + } + } + } + + // remove_all_extents gets underlying type of table + using strided_fn_ptr_T = + typename std::remove_all_extents::type; + strided_fn_ptr_T fn = nullptr; + + if (supports_atomics) { + fn = atomic_dispatch_table[src_typeid][dst_typeid]; + } + + if (fn == nullptr) { + // use slower reduction implementation using temporaries + fn = temps_dispatch_table[src_typeid][dst_typeid]; + if (fn == nullptr) { + throw std::runtime_error("Datatypes are not supported"); + } + } + + std::vector host_task_events{}; + + using dpctl::tensor::offset_utils::device_allocate_and_pack; + + const auto &arrays_metainfo_packing_triple_ = + device_allocate_and_pack( + exec_q, host_task_events, + // iteration metadata + simplified_iteration_shape, simplified_iteration_src_strides, + simplified_iteration_dst_strides, + // reduction metadata + simplified_reduction_shape, simplified_reduction_src_strides); + py::ssize_t *temp_allocation_ptr = + std::get<0>(arrays_metainfo_packing_triple_); + if (temp_allocation_ptr == nullptr) { + throw std::runtime_error("Unable to allocate memory on device"); + } + const auto ©_metadata_ev = std::get<2>(arrays_metainfo_packing_triple_); + + py::ssize_t *iter_shape_and_strides = temp_allocation_ptr; + py::ssize_t *reduction_shape_stride = + temp_allocation_ptr + 3 * simplified_iteration_shape.size(); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.resize(depends.size()); + std::copy(depends.begin(), depends.end(), all_deps.begin()); + all_deps.push_back(copy_metadata_ev); + + auto reduction_ev = + fn(exec_q, dst_nelems, reduction_nelems, src.get_data(), dst.get_data(), + iteration_nd, iter_shape_and_strides, iteration_src_offset, + iteration_dst_offset, + reduction_nd, // number dimensions being reduced + reduction_shape_stride, reduction_src_offset, all_deps); + + sycl::event temp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(reduction_ev); + const auto &ctx = exec_q.get_context(); + cgh.host_task([ctx, temp_allocation_ptr] { + sycl::free(temp_allocation_ptr, ctx); + }); + }); + host_task_events.push_back(temp_cleanup_ev); + + sycl::event keep_args_event = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, host_task_events); + + return std::make_pair(keep_args_event, reduction_ev); +} + +/* ==================== Search reductions ====================== */ + +template +std::pair py_search_over_axis( + const dpctl::tensor::usm_ndarray &src, + int trailing_dims_to_reduce, // comp over this many trailing indexes + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends, + const fn_tableT &dispatch_table) +{ + int src_nd = src.get_ndim(); + int iteration_nd = src_nd - trailing_dims_to_reduce; + if (trailing_dims_to_reduce <= 0 || iteration_nd < 0) { + throw py::value_error("Trailing_dim_to_reduce must be positive, but no " + "greater than rank of the array being reduced"); + } + + int dst_nd = dst.get_ndim(); + if (dst_nd != iteration_nd) { + throw py::value_error("Destination array rank does not match input " + "array rank and number of reduced dimensions"); + } + + const py::ssize_t *src_shape_ptr = src.get_shape_raw(); + const py::ssize_t *dst_shape_ptr = dst.get_shape_raw(); + + bool same_shapes = true; + for (int i = 0; same_shapes && (i < dst_nd); ++i) { + same_shapes = same_shapes && (src_shape_ptr[i] == dst_shape_ptr[i]); + } + + if (!same_shapes) { + throw py::value_error("Destination shape does not match unreduced " + "dimensions of the input shape"); + } + + if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + size_t dst_nelems = dst.get_size(); + + size_t reduction_nelems(1); + for (int i = dst_nd; i < src_nd; ++i) { + reduction_nelems *= static_cast(src_shape_ptr[i]); + } + + // check that dst and src do not overlap + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(src, dst)) { + throw py::value_error("Arrays index overlapping segments of memory"); + } + + // destination must be ample enough to accommodate all elements + { + auto dst_offsets = dst.get_minmax_offsets(); + size_t range = + static_cast(dst_offsets.second - dst_offsets.first); + if (range + 1 < dst_nelems) { + throw py::value_error( + "Destination array can not accommodate all the " + "elements of source array."); + } + } + + int src_typenum = src.get_typenum(); + int dst_typenum = dst.get_typenum(); + + namespace td_ns = dpctl::tensor::type_dispatch; + const auto &array_types = td_ns::usm_ndarray_types(); + int src_typeid = array_types.typenum_to_lookup_id(src_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + using dpctl::tensor::py_internal::simplify_iteration_space; + using dpctl::tensor::py_internal::simplify_iteration_space_1; + + auto const &src_shape_vecs = src.get_shape_vector(); + auto const &src_strides_vecs = src.get_strides_vector(); + auto const &dst_strides_vecs = dst.get_strides_vector(); + + int reduction_nd = trailing_dims_to_reduce; + const py::ssize_t *reduction_shape_ptr = src_shape_ptr + dst_nd; + using shT = std::vector; + shT reduction_src_strides(std::begin(src_strides_vecs) + dst_nd, + std::end(src_strides_vecs)); + + shT compact_reduction_shape; + shT compact_reduction_src_strides; + py::ssize_t reduction_src_offset(0); + + compact_iteration_space( + reduction_nd, reduction_shape_ptr, reduction_src_strides, + // output + compact_reduction_shape, compact_reduction_src_strides); + + const py::ssize_t *iteration_shape_ptr = src_shape_ptr; + + shT iteration_src_strides(std::begin(src_strides_vecs), + std::begin(src_strides_vecs) + iteration_nd); + shT const &iteration_dst_strides = dst_strides_vecs; + + shT simplified_iteration_shape; + shT simplified_iteration_src_strides; + shT simplified_iteration_dst_strides; + py::ssize_t iteration_src_offset(0); + py::ssize_t iteration_dst_offset(0); + + if (iteration_nd == 0) { + if (dst_nelems != 1) { + throw std::runtime_error("iteration_nd == 0, but dst_nelems != 1"); + } + iteration_nd = 1; + simplified_iteration_shape.push_back(1); + simplified_iteration_src_strides.push_back(0); + simplified_iteration_dst_strides.push_back(0); + } + else { + simplify_iteration_space(iteration_nd, iteration_shape_ptr, + iteration_src_strides, iteration_dst_strides, + // output + simplified_iteration_shape, + simplified_iteration_src_strides, + simplified_iteration_dst_strides, + iteration_src_offset, iteration_dst_offset); + } + + auto fn = dispatch_table[src_typeid][dst_typeid]; + if (fn == nullptr) { + throw std::runtime_error("Datatypes are not supported"); + } + + std::vector host_task_events{}; + + using dpctl::tensor::offset_utils::device_allocate_and_pack; + + const auto &arrays_metainfo_packing_triple_ = + device_allocate_and_pack( + exec_q, host_task_events, + // iteration metadata + simplified_iteration_shape, simplified_iteration_src_strides, + simplified_iteration_dst_strides, + // reduction metadata + compact_reduction_shape, compact_reduction_src_strides); + py::ssize_t *temp_allocation_ptr = + std::get<0>(arrays_metainfo_packing_triple_); + if (temp_allocation_ptr == nullptr) { + throw std::runtime_error("Unable to allocate memory on device"); + } + const auto ©_metadata_ev = std::get<2>(arrays_metainfo_packing_triple_); + + py::ssize_t *iter_shape_and_strides = temp_allocation_ptr; + py::ssize_t *reduction_shape_stride = + temp_allocation_ptr + 3 * simplified_iteration_shape.size(); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.resize(depends.size()); + std::copy(depends.begin(), depends.end(), all_deps.begin()); + all_deps.push_back(copy_metadata_ev); + + auto comp_ev = fn(exec_q, dst_nelems, reduction_nelems, src.get_data(), + dst.get_data(), iteration_nd, iter_shape_and_strides, + iteration_src_offset, iteration_dst_offset, + reduction_nd, // number dimensions being reduced + reduction_shape_stride, reduction_src_offset, all_deps); + + sycl::event temp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(comp_ev); + const auto &ctx = exec_q.get_context(); + cgh.host_task([ctx, temp_allocation_ptr] { + sycl::free(temp_allocation_ptr, ctx); + }); + }); + host_task_events.push_back(temp_cleanup_ev); + + sycl::event keep_args_event = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, host_task_events); + + return std::make_pair(keep_args_event, comp_ev); +} + +extern void init_reduction_functions(py::module_ m); + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/repeat.cpp b/dpctl/tensor/libtensor/source/repeat.cpp index f3a20cbbaa..fe11684ab9 100644 --- a/dpctl/tensor/libtensor/source/repeat.cpp +++ b/dpctl/tensor/libtensor/source/repeat.cpp @@ -23,11 +23,11 @@ //===--------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include #include +#include #include #include diff --git a/dpctl/tensor/libtensor/source/repeat.hpp b/dpctl/tensor/libtensor/source/repeat.hpp index 65ace36516..2d37aa33e9 100644 --- a/dpctl/tensor/libtensor/source/repeat.hpp +++ b/dpctl/tensor/libtensor/source/repeat.hpp @@ -23,7 +23,7 @@ //===--------------------------------------------------------------------===// #pragma once -#include +#include #include #include diff --git a/dpctl/tensor/libtensor/source/tensor_ctors.cpp b/dpctl/tensor/libtensor/source/tensor_ctors.cpp index 4720f6baa1..be2b20c18d 100644 --- a/dpctl/tensor/libtensor/source/tensor_ctors.cpp +++ b/dpctl/tensor/libtensor/source/tensor_ctors.cpp @@ -23,12 +23,12 @@ /// This file defines functions of dpctl.tensor._tensor_impl extensions //===----------------------------------------------------------------------===// -#include #include #include #include #include #include +#include #include #include #include diff --git a/dpctl/tensor/libtensor/source/triul_ctor.cpp b/dpctl/tensor/libtensor/source/triul_ctor.cpp index 40dd5cf48a..03fcd2994c 100644 --- a/dpctl/tensor/libtensor/source/triul_ctor.cpp +++ b/dpctl/tensor/libtensor/source/triul_ctor.cpp @@ -22,7 +22,7 @@ /// This file defines functions of dpctl.tensor._tensor_impl extensions //===--------------------------------------------------------------------===// -#include +#include #include #include diff --git a/dpctl/tensor/libtensor/source/triul_ctor.hpp b/dpctl/tensor/libtensor/source/triul_ctor.hpp index 9e7053c638..de80d20407 100644 --- a/dpctl/tensor/libtensor/source/triul_ctor.hpp +++ b/dpctl/tensor/libtensor/source/triul_ctor.hpp @@ -23,7 +23,7 @@ //===--------------------------------------------------------------------===// #pragma once -#include +#include #include #include diff --git a/dpctl/tensor/libtensor/source/where.cpp b/dpctl/tensor/libtensor/source/where.cpp index ed782bda34..e533fd2ee1 100644 --- a/dpctl/tensor/libtensor/source/where.cpp +++ b/dpctl/tensor/libtensor/source/where.cpp @@ -24,12 +24,12 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" -#include #include #include #include #include #include +#include #include #include "kernels/where.hpp" diff --git a/dpctl/tensor/libtensor/source/where.hpp b/dpctl/tensor/libtensor/source/where.hpp index 6fe6527080..2ca3b39e02 100644 --- a/dpctl/tensor/libtensor/source/where.hpp +++ b/dpctl/tensor/libtensor/source/where.hpp @@ -24,7 +24,7 @@ //===----------------------------------------------------------------------===// #pragma once -#include +#include #include #include diff --git a/dpctl/utils/CMakeLists.txt b/dpctl/utils/CMakeLists.txt index aadc1c0fe0..e7d3951e5b 100644 --- a/dpctl/utils/CMakeLists.txt +++ b/dpctl/utils/CMakeLists.txt @@ -21,6 +21,19 @@ pybind11_add_module(${python_module_name} MODULE ${_module_src} ) add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_module_src}) +if(_dpctl_sycl_targets) + # make fat binary + target_compile_options( + ${python_module_name} + PRIVATE + -fsycl-targets=${_dpctl_sycl_targets} + ) + target_link_options( + ${python_module_name} + PRIVATE + -fsycl-targets=${_dpctl_sycl_targets} + ) +endif() target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../include diff --git a/libsyclinterface/CMakeLists.txt b/libsyclinterface/CMakeLists.txt index 64ec3271b1..e84959c1c3 100644 --- a/libsyclinterface/CMakeLists.txt +++ b/libsyclinterface/CMakeLists.txt @@ -205,6 +205,19 @@ add_library(DPCTLSyclInterface ${helper_sources} ) add_sycl_to_target(TARGET DPCTLSyclInterface SOURCES ${sources} ${helper_sources}) +# make fat binary +if(_dpctl_sycl_targets) + target_compile_options( + DPCTLSyclInterface + PRIVATE + -fsycl-targets=${_dpctl_sycl_targets} + ) + target_link_options( + DPCTLSyclInterface + PRIVATE + -fsycl-targets=${_dpctl_sycl_targets} + ) +endif() if(DPCTL_GENERATE_COVERAGE) target_link_options(DPCTLSyclInterface diff --git a/libsyclinterface/helper/include/dpctl_error_handlers.h b/libsyclinterface/helper/include/dpctl_error_handlers.h index 2947dd1d5e..5c7c12f1a4 100644 --- a/libsyclinterface/helper/include/dpctl_error_handlers.h +++ b/libsyclinterface/helper/include/dpctl_error_handlers.h @@ -27,7 +27,7 @@ #include "Support/DllExport.h" #include "dpctl_error_handler_type.h" -#include +#include /*! * @brief Functor class used by DPCTL to handle SYCL asynchronous errors. diff --git a/libsyclinterface/helper/include/dpctl_utils_helper.h b/libsyclinterface/helper/include/dpctl_utils_helper.h index 9ed29514df..f7f484a32a 100644 --- a/libsyclinterface/helper/include/dpctl_utils_helper.h +++ b/libsyclinterface/helper/include/dpctl_utils_helper.h @@ -26,7 +26,7 @@ #include "Support/DllExport.h" #include "dpctl_sycl_enum_types.h" -#include +#include /*! * @brief Converts a sycl::info::device_type input value to a string. diff --git a/libsyclinterface/include/dpctl_device_selection.hpp b/libsyclinterface/include/dpctl_device_selection.hpp index 9da0072ab1..605078586c 100644 --- a/libsyclinterface/include/dpctl_device_selection.hpp +++ b/libsyclinterface/include/dpctl_device_selection.hpp @@ -28,7 +28,7 @@ #pragma once #include "Support/DllExport.h" -#include +#include namespace dpctl { diff --git a/libsyclinterface/include/dpctl_sycl_type_casters.hpp b/libsyclinterface/include/dpctl_sycl_type_casters.hpp index 470165afdd..107fc43ff4 100644 --- a/libsyclinterface/include/dpctl_sycl_type_casters.hpp +++ b/libsyclinterface/include/dpctl_sycl_type_casters.hpp @@ -30,7 +30,7 @@ #include "dpctl_device_selection.hpp" #include "dpctl_sycl_types.h" -#include +#include #include namespace dpctl::syclinterface diff --git a/libsyclinterface/source/dpctl_device_selection.cpp b/libsyclinterface/source/dpctl_device_selection.cpp index 7203bc3b1a..299ca5be41 100644 --- a/libsyclinterface/source/dpctl_device_selection.cpp +++ b/libsyclinterface/source/dpctl_device_selection.cpp @@ -27,7 +27,7 @@ #include "dpctl_device_selection.hpp" #include "Config/dpctl_config.h" -#include +#include namespace { diff --git a/libsyclinterface/source/dpctl_sycl_context_interface.cpp b/libsyclinterface/source/dpctl_sycl_context_interface.cpp index a19286a779..ab9923652c 100644 --- a/libsyclinterface/source/dpctl_sycl_context_interface.cpp +++ b/libsyclinterface/source/dpctl_sycl_context_interface.cpp @@ -28,7 +28,7 @@ #include "Config/dpctl_config.h" #include "dpctl_error_handlers.h" #include "dpctl_sycl_type_casters.hpp" -#include +#include #include #include diff --git a/libsyclinterface/source/dpctl_sycl_device_interface.cpp b/libsyclinterface/source/dpctl_sycl_device_interface.cpp index 7a159a331c..72a02e9261 100644 --- a/libsyclinterface/source/dpctl_sycl_device_interface.cpp +++ b/libsyclinterface/source/dpctl_sycl_device_interface.cpp @@ -32,9 +32,9 @@ #include "dpctl_sycl_device_manager.h" #include "dpctl_sycl_type_casters.hpp" #include "dpctl_utils_helper.h" -#include /* SYCL headers */ #include #include +#include /* SYCL headers */ #include #include diff --git a/libsyclinterface/source/dpctl_sycl_device_manager.cpp b/libsyclinterface/source/dpctl_sycl_device_manager.cpp index 0eb71df412..f36f5db21e 100644 --- a/libsyclinterface/source/dpctl_sycl_device_manager.cpp +++ b/libsyclinterface/source/dpctl_sycl_device_manager.cpp @@ -29,10 +29,10 @@ #include "dpctl_sycl_enum_types.h" #include "dpctl_sycl_type_casters.hpp" #include "dpctl_utils_helper.h" -#include /* SYCL headers */ #include /* Config */ #include #include +#include /* SYCL headers */ #include #include #include diff --git a/libsyclinterface/source/dpctl_sycl_device_selector_interface.cpp b/libsyclinterface/source/dpctl_sycl_device_selector_interface.cpp index 9753c32613..834e9a57a2 100644 --- a/libsyclinterface/source/dpctl_sycl_device_selector_interface.cpp +++ b/libsyclinterface/source/dpctl_sycl_device_selector_interface.cpp @@ -28,7 +28,7 @@ #include "dpctl_device_selection.hpp" #include "dpctl_error_handlers.h" #include "dpctl_sycl_type_casters.hpp" -#include /* SYCL headers */ +#include /* SYCL headers */ using namespace sycl; diff --git a/libsyclinterface/source/dpctl_sycl_event_interface.cpp b/libsyclinterface/source/dpctl_sycl_event_interface.cpp index 3f872f4493..7a109faca9 100644 --- a/libsyclinterface/source/dpctl_sycl_event_interface.cpp +++ b/libsyclinterface/source/dpctl_sycl_event_interface.cpp @@ -29,7 +29,7 @@ #include "dpctl_error_handlers.h" #include "dpctl_sycl_type_casters.hpp" #include "dpctl_utils_helper.h" -#include /* SYCL headers */ +#include /* SYCL headers */ #include using namespace sycl; diff --git a/libsyclinterface/source/dpctl_sycl_kernel_bundle_interface.cpp b/libsyclinterface/source/dpctl_sycl_kernel_bundle_interface.cpp index 201c8172e3..d32f278c07 100644 --- a/libsyclinterface/source/dpctl_sycl_kernel_bundle_interface.cpp +++ b/libsyclinterface/source/dpctl_sycl_kernel_bundle_interface.cpp @@ -30,14 +30,10 @@ #include "dpctl_dynamic_lib_helper.h" #include "dpctl_error_handlers.h" #include "dpctl_sycl_type_casters.hpp" -#include /* OpenCL headers */ -#include /* Sycl headers */ -#if __has_include() -#include -#else -#include -#endif +#include /* OpenCL headers */ #include +#include +#include /* Sycl headers */ #include #ifdef DPCTL_ENABLE_L0_PROGRAM_CREATION @@ -45,11 +41,7 @@ // not reorder the includes. // clang-format off #include "ze_api.h" /* Level Zero headers */ -#if __has_include() #include -#else -#include -#endif // clang-format on #endif diff --git a/libsyclinterface/source/dpctl_sycl_kernel_interface.cpp b/libsyclinterface/source/dpctl_sycl_kernel_interface.cpp index 8a5af3f179..abd7f9a443 100644 --- a/libsyclinterface/source/dpctl_sycl_kernel_interface.cpp +++ b/libsyclinterface/source/dpctl_sycl_kernel_interface.cpp @@ -29,8 +29,8 @@ #include "dpctl_error_handlers.h" #include "dpctl_string_utils.hpp" #include "dpctl_sycl_type_casters.hpp" -#include /* Sycl headers */ #include +#include /* Sycl headers */ using namespace sycl; diff --git a/libsyclinterface/source/dpctl_sycl_platform_interface.cpp b/libsyclinterface/source/dpctl_sycl_platform_interface.cpp index fb0fbd6bd2..409b600355 100644 --- a/libsyclinterface/source/dpctl_sycl_platform_interface.cpp +++ b/libsyclinterface/source/dpctl_sycl_platform_interface.cpp @@ -31,11 +31,11 @@ #include "dpctl_string_utils.hpp" #include "dpctl_sycl_type_casters.hpp" #include "dpctl_utils_helper.h" -#include #include #include #include #include +#include #include #include diff --git a/libsyclinterface/source/dpctl_sycl_platform_manager.cpp b/libsyclinterface/source/dpctl_sycl_platform_manager.cpp index 6717b48c6f..f01f7a76a5 100644 --- a/libsyclinterface/source/dpctl_sycl_platform_manager.cpp +++ b/libsyclinterface/source/dpctl_sycl_platform_manager.cpp @@ -31,11 +31,11 @@ #include "dpctl_sycl_platform_interface.h" #include "dpctl_sycl_type_casters.hpp" #include "dpctl_utils_helper.h" -#include #include #include #include #include +#include using namespace sycl; diff --git a/libsyclinterface/source/dpctl_sycl_queue_interface.cpp b/libsyclinterface/source/dpctl_sycl_queue_interface.cpp index 60098ae933..6612c22ef9 100644 --- a/libsyclinterface/source/dpctl_sycl_queue_interface.cpp +++ b/libsyclinterface/source/dpctl_sycl_queue_interface.cpp @@ -31,9 +31,9 @@ #include "dpctl_sycl_device_interface.h" #include "dpctl_sycl_device_manager.h" #include "dpctl_sycl_type_casters.hpp" -#include /* SYCL headers */ #include #include +#include /* SYCL headers */ #include using namespace sycl; diff --git a/libsyclinterface/source/dpctl_sycl_queue_manager.cpp b/libsyclinterface/source/dpctl_sycl_queue_manager.cpp index 54e97c0efa..651689e105 100644 --- a/libsyclinterface/source/dpctl_sycl_queue_manager.cpp +++ b/libsyclinterface/source/dpctl_sycl_queue_manager.cpp @@ -28,7 +28,7 @@ #include "dpctl_error_handlers.h" #include "dpctl_sycl_device_manager.h" #include "dpctl_sycl_type_casters.hpp" -#include /* SYCL headers */ +#include /* SYCL headers */ #include using namespace sycl; diff --git a/libsyclinterface/source/dpctl_sycl_usm_interface.cpp b/libsyclinterface/source/dpctl_sycl_usm_interface.cpp index 2ebae9801e..b993ee32a8 100644 --- a/libsyclinterface/source/dpctl_sycl_usm_interface.cpp +++ b/libsyclinterface/source/dpctl_sycl_usm_interface.cpp @@ -29,7 +29,7 @@ #include "dpctl_error_handlers.h" #include "dpctl_sycl_device_interface.h" #include "dpctl_sycl_type_casters.hpp" -#include /* SYCL headers */ +#include /* SYCL headers */ #include using namespace sycl; diff --git a/libsyclinterface/tests/CMakeLists.txt b/libsyclinterface/tests/CMakeLists.txt index 472e1787fa..5a672e312f 100644 --- a/libsyclinterface/tests/CMakeLists.txt +++ b/libsyclinterface/tests/CMakeLists.txt @@ -52,6 +52,19 @@ add_sycl_to_target( ${CMAKE_CURRENT_SOURCE_DIR}/test_sycl_queue_interface.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_sycl_usm_interface.cpp ) +if (_dpctl_sycl_targets) +# make fat binary +target_compile_options( + dpctl_c_api_tests + PRIVATE + -fsycl-targets=nvptx64-nvidia-cuda,spir64-unknown-unknown +) +target_link_options( + dpctl_c_api_tests + PRIVATE + -fsycl-targets=nvptx64-nvidia-cuda,spir64-unknown-unknown +) +endif() if(DPCTL_GENERATE_COVERAGE) target_include_directories(dpctl_c_api_tests diff --git a/libsyclinterface/tests/test_helper.cpp b/libsyclinterface/tests/test_helper.cpp index ea529cbf24..467274849d 100644 --- a/libsyclinterface/tests/test_helper.cpp +++ b/libsyclinterface/tests/test_helper.cpp @@ -26,9 +26,9 @@ #include "Config/dpctl_config.h" #include "dpctl_utils_helper.h" -#include #include #include +#include struct TestHelperFns : public ::testing::Test { diff --git a/libsyclinterface/tests/test_sycl_context_interface.cpp b/libsyclinterface/tests/test_sycl_context_interface.cpp index 36b2ff6e97..75fbbe7cb2 100644 --- a/libsyclinterface/tests/test_sycl_context_interface.cpp +++ b/libsyclinterface/tests/test_sycl_context_interface.cpp @@ -29,8 +29,8 @@ #include "dpctl_sycl_device_interface.h" #include "dpctl_sycl_device_selector_interface.h" #include "dpctl_sycl_types.h" -#include #include +#include #include using namespace sycl; diff --git a/libsyclinterface/tests/test_sycl_device_aspects.cpp b/libsyclinterface/tests/test_sycl_device_aspects.cpp index 9019d7f718..e2e42db74d 100644 --- a/libsyclinterface/tests/test_sycl_device_aspects.cpp +++ b/libsyclinterface/tests/test_sycl_device_aspects.cpp @@ -30,8 +30,8 @@ #include "dpctl_sycl_enum_types.h" #include "dpctl_sycl_type_casters.hpp" #include "dpctl_utils_helper.h" -#include #include +#include #include namespace diff --git a/libsyclinterface/tests/test_sycl_device_interface.cpp b/libsyclinterface/tests/test_sycl_device_interface.cpp index dd20c738df..a0544482ff 100644 --- a/libsyclinterface/tests/test_sycl_device_interface.cpp +++ b/libsyclinterface/tests/test_sycl_device_interface.cpp @@ -29,8 +29,8 @@ #include "dpctl_sycl_platform_interface.h" #include "dpctl_utils.h" #include "dpctl_utils_helper.h" -#include #include +#include using namespace sycl; diff --git a/libsyclinterface/tests/test_sycl_device_invalid_filters.cpp b/libsyclinterface/tests/test_sycl_device_invalid_filters.cpp index c6a722c87a..50cfc6ba67 100644 --- a/libsyclinterface/tests/test_sycl_device_invalid_filters.cpp +++ b/libsyclinterface/tests/test_sycl_device_invalid_filters.cpp @@ -25,8 +25,8 @@ #include "dpctl_sycl_device_interface.h" #include "dpctl_sycl_device_selector_interface.h" -#include #include +#include using namespace sycl; diff --git a/libsyclinterface/tests/test_sycl_device_selector_interface.cpp b/libsyclinterface/tests/test_sycl_device_selector_interface.cpp index eff7e4ee41..8e5df58769 100644 --- a/libsyclinterface/tests/test_sycl_device_selector_interface.cpp +++ b/libsyclinterface/tests/test_sycl_device_selector_interface.cpp @@ -28,8 +28,8 @@ #include "dpctl_sycl_device_manager.h" #include "dpctl_sycl_device_selector_interface.h" #include "dpctl_sycl_type_casters.hpp" -#include #include +#include using namespace sycl; diff --git a/libsyclinterface/tests/test_sycl_device_subdevices.cpp b/libsyclinterface/tests/test_sycl_device_subdevices.cpp index cb73359059..f2039c6dcf 100644 --- a/libsyclinterface/tests/test_sycl_device_subdevices.cpp +++ b/libsyclinterface/tests/test_sycl_device_subdevices.cpp @@ -32,8 +32,8 @@ #include "dpctl_sycl_type_casters.hpp" #include "dpctl_utils.h" #include "dpctl_utils_helper.h" -#include #include +#include using namespace sycl; using namespace dpctl::syclinterface; diff --git a/libsyclinterface/tests/test_sycl_event_interface.cpp b/libsyclinterface/tests/test_sycl_event_interface.cpp index 0cc11af731..615755ebc3 100644 --- a/libsyclinterface/tests/test_sycl_event_interface.cpp +++ b/libsyclinterface/tests/test_sycl_event_interface.cpp @@ -27,8 +27,8 @@ #include "Config/dpctl_config.h" #include "dpctl_sycl_event_interface.h" #include "dpctl_sycl_types.h" -#include #include +#include #include using namespace sycl; diff --git a/libsyclinterface/tests/test_sycl_kernel_bundle_interface.cpp b/libsyclinterface/tests/test_sycl_kernel_bundle_interface.cpp index 6383b730a0..c450d6722d 100644 --- a/libsyclinterface/tests/test_sycl_kernel_bundle_interface.cpp +++ b/libsyclinterface/tests/test_sycl_kernel_bundle_interface.cpp @@ -34,11 +34,11 @@ #include "dpctl_sycl_kernel_interface.h" #include "dpctl_sycl_queue_interface.h" #include "dpctl_sycl_queue_manager.h" -#include #include #include #include #include +#include using namespace sycl; diff --git a/libsyclinterface/tests/test_sycl_kernel_interface.cpp b/libsyclinterface/tests/test_sycl_kernel_interface.cpp index 97fba96bc3..d7e7cb4087 100644 --- a/libsyclinterface/tests/test_sycl_kernel_interface.cpp +++ b/libsyclinterface/tests/test_sycl_kernel_interface.cpp @@ -33,9 +33,9 @@ #include "dpctl_sycl_queue_interface.h" #include "dpctl_sycl_queue_manager.h" #include "dpctl_utils.h" -#include #include #include +#include using namespace sycl; diff --git a/libsyclinterface/tests/test_sycl_platform_interface.cpp b/libsyclinterface/tests/test_sycl_platform_interface.cpp index f04cead0e1..3164aef7ec 100644 --- a/libsyclinterface/tests/test_sycl_platform_interface.cpp +++ b/libsyclinterface/tests/test_sycl_platform_interface.cpp @@ -29,8 +29,8 @@ #include "dpctl_sycl_platform_interface.h" #include "dpctl_sycl_platform_manager.h" #include "dpctl_utils.h" -#include #include +#include #include using namespace sycl; diff --git a/libsyclinterface/tests/test_sycl_platform_invalid_filters.cpp b/libsyclinterface/tests/test_sycl_platform_invalid_filters.cpp index 5499f88430..41f0cc4a53 100644 --- a/libsyclinterface/tests/test_sycl_platform_invalid_filters.cpp +++ b/libsyclinterface/tests/test_sycl_platform_invalid_filters.cpp @@ -26,8 +26,8 @@ #include "dpctl_sycl_device_selector_interface.h" #include "dpctl_sycl_platform_interface.h" -#include #include +#include using namespace sycl; diff --git a/libsyclinterface/tests/test_sycl_queue_interface.cpp b/libsyclinterface/tests/test_sycl_queue_interface.cpp index 836a87379b..7fc0d39970 100644 --- a/libsyclinterface/tests/test_sycl_queue_interface.cpp +++ b/libsyclinterface/tests/test_sycl_queue_interface.cpp @@ -34,8 +34,8 @@ #include "dpctl_sycl_queue_manager.h" #include "dpctl_sycl_type_casters.hpp" #include "dpctl_sycl_usm_interface.h" -#include #include +#include using namespace sycl; using namespace dpctl::syclinterface; diff --git a/libsyclinterface/tests/test_sycl_queue_manager.cpp b/libsyclinterface/tests/test_sycl_queue_manager.cpp index 4f9e84ea20..0fc640f4ab 100644 --- a/libsyclinterface/tests/test_sycl_queue_manager.cpp +++ b/libsyclinterface/tests/test_sycl_queue_manager.cpp @@ -30,8 +30,8 @@ #include "dpctl_sycl_queue_interface.h" #include "dpctl_sycl_queue_manager.h" #include "dpctl_sycl_type_casters.hpp" -#include #include +#include #include using namespace std; diff --git a/libsyclinterface/tests/test_sycl_queue_submit.cpp b/libsyclinterface/tests/test_sycl_queue_submit.cpp index 680314b719..3ef37978d3 100644 --- a/libsyclinterface/tests/test_sycl_queue_submit.cpp +++ b/libsyclinterface/tests/test_sycl_queue_submit.cpp @@ -32,10 +32,10 @@ #include "dpctl_sycl_queue_interface.h" #include "dpctl_sycl_type_casters.hpp" #include "dpctl_sycl_usm_interface.h" -#include #include #include #include +#include namespace { diff --git a/libsyclinterface/tests/test_sycl_usm_interface.cpp b/libsyclinterface/tests/test_sycl_usm_interface.cpp index a6dbb2290a..99f8e52051 100644 --- a/libsyclinterface/tests/test_sycl_usm_interface.cpp +++ b/libsyclinterface/tests/test_sycl_usm_interface.cpp @@ -32,9 +32,9 @@ #include "dpctl_sycl_queue_manager.h" #include "dpctl_sycl_type_casters.hpp" #include "dpctl_sycl_usm_interface.h" -#include #include #include +#include using namespace sycl; From c77344caf1071a7efa5373a0c0a2beba0b3ca13c Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 16 Nov 2023 16:51:02 -0600 Subject: [PATCH 2/4] Allow for change of name and location of sycl_complex.hpp Introduced private header to load SYCL's experimental complex header from the right location. The header and implementations respond to USE_SYCL_FOR_COMPLEX_TYPES preprocessor variable. If set, sycl::ext::oneapi::experimental namespace functions are to be used. Otherwise std:: namespace functions will be used instead for complex types. USE_SYCL_FOR_COMPLEX_TYPES is being set in tensor/CMakeLists.txt If USE_SYCL_FOR_COMPLEX_TYPES is not set, std:: functions are used except for sqrt and abs functions. For abs we use hypot(std::real(z), std::imag(z)) and for sqrt we use custom implementation on Windows to avoid failure to offload for single precision type due to unwarranted use of double precision types in the implementation for single precision inputs iin MS VC headers --- dpctl/tensor/CMakeLists.txt | 4 +-- .../kernels/elementwise_functions/abs.hpp | 5 ++-- .../kernels/elementwise_functions/acos.hpp | 17 ++++++++++-- .../kernels/elementwise_functions/acosh.hpp | 13 +++++++-- .../kernels/elementwise_functions/add.hpp | 15 +++++++++-- .../kernels/elementwise_functions/asin.hpp | 26 +++++++++++++++--- .../kernels/elementwise_functions/asinh.hpp | 13 +++++++-- .../kernels/elementwise_functions/atan.hpp | 7 +++-- .../kernels/elementwise_functions/atanh.hpp | 7 +++-- .../kernels/elementwise_functions/conj.hpp | 7 +++-- .../kernels/elementwise_functions/cos.hpp | 7 +++-- .../kernels/elementwise_functions/cosh.hpp | 7 +++-- .../kernels/elementwise_functions/equal.hpp | 7 +++-- .../kernels/elementwise_functions/exp.hpp | 7 +++-- .../kernels/elementwise_functions/exp2.hpp | 7 +++-- .../kernels/elementwise_functions/log.hpp | 8 ++++-- .../kernels/elementwise_functions/log10.hpp | 8 ++++-- .../kernels/elementwise_functions/log2.hpp | 8 ++++-- .../elementwise_functions/multiply.hpp | 8 ++++-- .../kernels/elementwise_functions/pow.hpp | 11 ++++++-- .../kernels/elementwise_functions/sign.hpp | 7 +++-- .../kernels/elementwise_functions/sin.hpp | 7 +++-- .../kernels/elementwise_functions/sinh.hpp | 7 +++-- .../kernels/elementwise_functions/sqrt.hpp | 27 ++++++++++++------- .../kernels/elementwise_functions/square.hpp | 7 +++-- .../elementwise_functions/sycl_complex.hpp | 13 +++++++++ .../kernels/elementwise_functions/tan.hpp | 9 ++++--- .../kernels/elementwise_functions/tanh.hpp | 11 +++++--- .../elementwise_functions/true_divide.hpp | 27 ++++++++++++++++--- 29 files changed, 237 insertions(+), 70 deletions(-) create mode 100644 dpctl/tensor/libtensor/include/kernels/elementwise_functions/sycl_complex.hpp diff --git a/dpctl/tensor/CMakeLists.txt b/dpctl/tensor/CMakeLists.txt index f2454a9fdc..d6346bfd50 100644 --- a/dpctl/tensor/CMakeLists.txt +++ b/dpctl/tensor/CMakeLists.txt @@ -188,9 +188,9 @@ foreach(_src_fn ${_no_fast_math_sources}) ) endforeach() if (UNIX) - set(_compiler_definitions "USE_STD_ABS_FOR_COMPLEX_TYPES;USE_STD_SQRT_FOR_COMPLEX_TYPES;SYCL_EXT_ONEAPI_COMPLEX") + set(_compiler_definitions "USE_SYCL_FOR_COMPLEX_TYPES") else() - set(_compiler_definitions "SYCL_EXT_ONEAPI_COMPLEX") + set(_compiler_definitions "USE_SYCL_FOR_COMPLEX_TYPES") endif() foreach(_src_fn ${_elementwise_sources}) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp index ab321ad356..911452931e 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp @@ -28,11 +28,11 @@ #include #include #include -#include #include #include #include "kernels/elementwise_functions/common.hpp" +#include "sycl_complex.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" @@ -50,7 +50,6 @@ namespace abs namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; -namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -121,7 +120,7 @@ template struct AbsFunctor return q_nan; } else { -#ifdef USE_STD_ABS_FOR_COMPLEX_TYPES +#ifdef USE_SYCL_FOR_COMPLEX_TYPES return exprm_ns::abs(exprm_ns::complex(z)); #else return std::hypot(std::real(z), std::imag(z)); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp index 23a87b9d44..c4742e66dc 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp @@ -26,11 +26,11 @@ #include #include #include -#include #include #include #include "kernels/elementwise_functions/common.hpp" +#include "sycl_complex.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" @@ -48,7 +48,6 @@ namespace acos namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; -namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -105,6 +104,7 @@ template struct AcosFunctor constexpr realT r_eps = realT(1) / std::numeric_limits::epsilon(); if (std::abs(x) > r_eps || std::abs(y) > r_eps) { +#ifdef USE_SYCL_FOR_COMPLEX_TYPES using sycl_complexT = exprm_ns::complex; sycl_complexT log_in = exprm_ns::log(exprm_ns::complex(in)); @@ -115,11 +115,24 @@ template struct AcosFunctor realT ry = wx + std::log(realT(2)); return resT{rx, (std::signbit(y)) ? ry : -ry}; +#else + resT log_in = std::log(in); + const realT wx = std::real(log_in); + const realT wy = std::imag(log_in); + const realT rx = std::abs(wy); + + realT ry = wx + std::log(realT(2)); + return resT{rx, (std::signbit(y)) ? ry : -ry}; +#endif } /* ordinary cases */ +#if USE_SYCL_FOR_COMPLEX_TYPES return exprm_ns::acos( exprm_ns::complex(in)); // std::acos(in); +#else + return std::acos(in); +#endif } else { static_assert(std::is_floating_point_v || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp index 56730a411c..b736d5b658 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp @@ -26,11 +26,11 @@ #include #include #include -#include #include #include #include "kernels/elementwise_functions/common.hpp" +#include "sycl_complex.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" @@ -48,7 +48,6 @@ namespace acosh namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; -namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -112,18 +111,28 @@ template struct AcoshFunctor * For large x or y including acos(+-Inf + I*+-Inf) */ if (std::abs(x) > r_eps || std::abs(y) > r_eps) { +#ifdef USE_SYCL_FOR_COMPLEX_TYPES using sycl_complexT = typename exprm_ns::complex; const sycl_complexT log_in = exprm_ns::log(sycl_complexT(in)); const realT wx = log_in.real(); const realT wy = log_in.imag(); +#else + const resT log_in = std::log(in); + const realT wx = std::real(log_in); + const realT wy = std::imag(log_in); +#endif const realT rx = std::abs(wy); realT ry = wx + std::log(realT(2)); acos_in = resT{rx, (std::signbit(y)) ? ry : -ry}; } else { /* ordinary cases */ +#if USE_SYCL_FOR_COMPLEX_TYPES acos_in = exprm_ns::acos( exprm_ns::complex(in)); // std::acos(in); +#else + acos_in = std::acos(in); +#endif } /* Now we calculate acosh(z) */ diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp index 0ed1710833..1297847831 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp @@ -26,10 +26,10 @@ #pragma once #include #include -#include #include #include +#include "sycl_complex.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" #include "utils/type_utils.hpp" @@ -50,7 +50,6 @@ namespace add namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; -namespace exprm_ns = sycl::ext::oneapi::experimental; template struct AddFunctor { @@ -65,24 +64,36 @@ template struct AddFunctor if constexpr (tu_ns::is_complex::value && tu_ns::is_complex::value) { +#ifdef USE_SYCL_FOR_COMPLEX_TYPES using rT1 = typename argT1::value_type; using rT2 = typename argT2::value_type; return exprm_ns::complex(in1) + exprm_ns::complex(in2); +#else + return in1 + in2; +#endif } else if constexpr (tu_ns::is_complex::value && !tu_ns::is_complex::value) { +#ifdef USE_SYCL_FOR_COMPLEX_TYPES using rT1 = typename argT1::value_type; return exprm_ns::complex(in1) + in2; +#else + return in1 + in2; +#endif } else if constexpr (!tu_ns::is_complex::value && tu_ns::is_complex::value) { +#ifdef USE_SYCL_FOR_COMPLEX_TYPES using rT2 = typename argT2::value_type; return in1 + exprm_ns::complex(in2); +#else + return in1 + in2; +#endif } else { return in1 + in2; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp index 035480c437..9da5077665 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp @@ -26,11 +26,11 @@ #include #include #include -#include #include #include #include "kernels/elementwise_functions/common.hpp" +#include "sycl_complex.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" @@ -48,7 +48,6 @@ namespace asin namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; -namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -119,26 +118,45 @@ template struct AsinFunctor constexpr realT r_eps = realT(1) / std::numeric_limits::epsilon(); if (std::abs(x) > r_eps || std::abs(y) > r_eps) { +#ifdef USE_SYCL_FOR_COMPLEX_TYPES using sycl_complexT = exprm_ns::complex; const sycl_complexT z{x, y}; realT wx, wy; if (!std::signbit(x)) { - auto log_z = exprm_ns::log(z); + const auto log_z = exprm_ns::log(z); wx = log_z.real() + std::log(realT(2)); wy = log_z.imag(); } else { - auto log_mz = exprm_ns::log(-z); + const auto log_mz = exprm_ns::log(-z); wx = log_mz.real() + std::log(realT(2)); wy = log_mz.imag(); } +#else + const resT z{x, y}; + realT wx, wy; + if (!std::signbit(x)) { + const auto log_z = std::log(z); + wx = std::real(log_z) + std::log(realT(2)); + wy = std::imag(log_z); + } + else { + const auto log_mz = std::log(-z); + wx = std::real(log_mz) + std::log(realT(2)); + wy = std::imag(log_mz); + } +#endif const realT asinh_re = std::copysign(wx, x); const realT asinh_im = std::copysign(wy, y); return resT{asinh_im, asinh_re}; } /* ordinary cases */ +#if USE_SYCL_FOR_COMPLEX_TYPES return exprm_ns::asin( exprm_ns::complex(in)); // std::asin(in); +#else + return std::asin(in); +#endif } else { static_assert(std::is_floating_point_v || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp index 523ca4f01f..ceab296b91 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp @@ -26,11 +26,11 @@ #include #include #include -#include #include #include #include "kernels/elementwise_functions/common.hpp" +#include "sycl_complex.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" @@ -48,7 +48,6 @@ namespace asinh namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; -namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -108,20 +107,30 @@ template struct AsinhFunctor realT(1) / std::numeric_limits::epsilon(); if (std::abs(x) > r_eps || std::abs(y) > r_eps) { +#ifdef USE_SYCL_FOR_COMPLEX_TYPES using sycl_complexT = exprm_ns::complex; sycl_complexT log_in = (std::signbit(x)) ? exprm_ns::log(sycl_complexT(-in)) : exprm_ns::log(sycl_complexT(in)); realT wx = log_in.real() + std::log(realT(2)); realT wy = log_in.imag(); +#else + auto log_in = std::log(std::signbit(x) ? -in : in); + realT wx = std::real(log_in) + std::log(realT(2)); + realT wy = std::imag(log_in); +#endif const realT res_re = std::copysign(wx, x); const realT res_im = std::copysign(wy, y); return resT{res_re, res_im}; } /* ordinary cases */ +#if USE_SYCL_FOR_COMPLEX_TYPES return exprm_ns::asinh( exprm_ns::complex(in)); // std::asinh(in); +#else + return std::asinh(in); +#endif } else { static_assert(std::is_floating_point_v || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp index df8bba538b..a9af6d829e 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp @@ -27,11 +27,11 @@ #include #include #include -#include #include #include #include "kernels/elementwise_functions/common.hpp" +#include "sycl_complex.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" @@ -49,7 +49,6 @@ namespace atan namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; -namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -128,8 +127,12 @@ template struct AtanFunctor return resT{atanh_im, atanh_re}; } /* ordinary cases */ +#ifdef USE_SYCL_FOR_COMPLEX_TYPES return exprm_ns::atan( exprm_ns::complex(in)); // std::atan(in); +#else + return std::atan(in); +#endif } else { static_assert(std::is_floating_point_v || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp index d6a4b06ac3..3be6abb742 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp @@ -27,11 +27,11 @@ #include #include #include -#include #include #include #include "kernels/elementwise_functions/common.hpp" +#include "sycl_complex.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" @@ -49,7 +49,6 @@ namespace atanh namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; -namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -121,8 +120,12 @@ template struct AtanhFunctor return resT{res_re, res_im}; } /* ordinary cases */ +#ifdef USE_SYCL_FOR_COMPLEX_TYPES return exprm_ns::atanh( exprm_ns::complex(in)); // std::atanh(in); +#else + return std::atanh(in); +#endif } else { static_assert(std::is_floating_point_v || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp index 6977e3a747..a3c607aa49 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp @@ -28,11 +28,11 @@ #include #include #include -#include #include #include #include "kernels/elementwise_functions/common.hpp" +#include "sycl_complex.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" @@ -50,7 +50,6 @@ namespace conj namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; -namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -70,9 +69,13 @@ template struct ConjFunctor resT operator()(const argT &in) const { if constexpr (is_complex::value) { +#ifdef USE_SYCL_FOR_COMPLEX_TYPES using rT = typename argT::value_type; return exprm_ns::conj(exprm_ns::complex(in)); // std::conj(in); +#else + return std::conj(in); +#endif } else { if constexpr (!std::is_same_v) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp index bdc1acc1fe..c8cd8ef18c 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp @@ -26,11 +26,11 @@ #include #include #include -#include #include #include #include "kernels/elementwise_functions/common.hpp" +#include "sycl_complex.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" @@ -48,7 +48,6 @@ namespace cos namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; -namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -83,8 +82,12 @@ template struct CosFunctor * real and imaginary parts of input are finite. */ if (in_re_finite && in_im_finite) { +#ifdef USE_SYCL_FOR_COMPLEX_TYPES return exprm_ns::cos( exprm_ns::complex(in)); // std::cos(in); +#else + return std::cos(in); +#endif } /* diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp index 7093d2a2a3..f03d438cbe 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp @@ -26,11 +26,11 @@ #include #include #include -#include #include #include #include "kernels/elementwise_functions/common.hpp" +#include "sycl_complex.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" @@ -48,7 +48,6 @@ namespace cosh namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; -namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -83,8 +82,12 @@ template struct CoshFunctor * real and imaginary parts of input are finite. */ if (xfinite && yfinite) { +#ifdef USE_SYCL_FOR_COMPLEX_TYPES return exprm_ns::cosh( exprm_ns::complex(in)); // std::cosh(in); +#else + return std::cosh(in); +#endif } /* diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp index 6d68861396..9f866aa580 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp @@ -26,10 +26,10 @@ #pragma once #include #include -#include #include #include +#include "sycl_complex.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" #include "utils/type_utils.hpp" @@ -49,7 +49,6 @@ namespace equal namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; -namespace exprm_ns = sycl::ext::oneapi::experimental; template struct EqualFunctor { @@ -70,8 +69,12 @@ template struct EqualFunctor using realT1 = typename argT1::value_type; using realT2 = typename argT2::value_type; +#ifdef USE_SYCL_FOR_COMPLEX_TYPES return exprm_ns::complex(in1) == exprm_ns::complex(in2); +#else + return (in1 == in2); +#endif } else { return (in1 == in2); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp index 453eb05c52..add15c3523 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp @@ -26,11 +26,11 @@ #include #include #include -#include #include #include #include "kernels/elementwise_functions/common.hpp" +#include "sycl_complex.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" @@ -48,7 +48,6 @@ namespace exp namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; -namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -75,8 +74,12 @@ template struct ExpFunctor const realT y = std::imag(in); if (std::isfinite(x)) { if (std::isfinite(y)) { +#ifdef USE_SYCL_FOR_COMPLEX_TYPES return exprm_ns::exp( exprm_ns::complex(in)); // std::exp(in); +#else + return std::exp(in); +#endif } else { return resT{q_nan, q_nan}; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp index b6b2f32e83..a22411414b 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp @@ -27,11 +27,11 @@ #include #include #include -#include #include #include #include "kernels/elementwise_functions/common.hpp" +#include "sycl_complex.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" @@ -49,7 +49,6 @@ namespace exp2 namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; -namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -78,7 +77,11 @@ template struct Exp2Functor const realT y = std::imag(tmp); if (std::isfinite(x)) { if (std::isfinite(y)) { +#ifdef USE_SYCL_FOR_COMPLEX_TYPES return exprm_ns::exp(exprm_ns::complex(tmp)); +#else + return std::exp(tmp); +#endif } else { return resT{q_nan, q_nan}; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp index ff37d87157..3668551473 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp @@ -24,13 +24,14 @@ #pragma once #include +#include #include #include -#include #include #include #include "kernels/elementwise_functions/common.hpp" +#include "sycl_complex.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" @@ -48,7 +49,6 @@ namespace log namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; -namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -68,8 +68,12 @@ template struct LogFunctor resT operator()(const argT &in) const { if constexpr (is_complex::value) { +#ifdef USE_SYCL_FOR_COMPLEX_TYPES using realT = typename argT::value_type; return exprm_ns::log(exprm_ns::complex(in)); // std::log(in); +#else + return std::log(in); +#endif } else { return std::log(in); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp index 88dabcaabe..5997aded5f 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp @@ -25,13 +25,14 @@ #pragma once #include +#include #include #include -#include #include #include #include "kernels/elementwise_functions/common.hpp" +#include "sycl_complex.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" @@ -49,7 +50,6 @@ namespace log10 namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; -namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; using dpctl::tensor::type_utils::vec_cast; @@ -72,9 +72,13 @@ template struct Log10Functor { if constexpr (is_complex::value) { using realT = typename argT::value_type; +#ifdef USE_SYCL_FOR_COMPLEX_TYPES // return (std::log(in) / std::log(realT{10})); return exprm_ns::log(exprm_ns::complex(in)) / std::log(realT{10}); +#else + return (std::log(in) / std::log(realT{10})); +#endif } else { return std::log10(in); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp index 57d7dcaf31..211a5fdb6c 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp @@ -25,13 +25,14 @@ #pragma once #include +#include #include #include -#include #include #include #include "kernels/elementwise_functions/common.hpp" +#include "sycl_complex.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" @@ -49,7 +50,6 @@ namespace log2 namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; -namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; using dpctl::tensor::type_utils::vec_cast; @@ -72,9 +72,13 @@ template struct Log2Functor { if constexpr (is_complex::value) { using realT = typename argT::value_type; +#ifdef USE_SYCL_FOR_COMPLEX_TYPES // std::log(in) / std::log(realT{2}); return exprm_ns::log(exprm_ns::complex(in)) / std::log(realT{2}); +#else + return std::log(in) / std::log(realT{2}); +#endif } else { return std::log2(in); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp index 612ad78360..47549b5a74 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp @@ -24,12 +24,13 @@ //===---------------------------------------------------------------------===// #pragma once +#include #include #include -#include #include #include +#include "sycl_complex.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" #include "utils/type_utils.hpp" @@ -50,7 +51,6 @@ namespace multiply namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; -namespace exprm_ns = sycl::ext::oneapi::experimental; template struct MultiplyFunctor { @@ -65,11 +65,15 @@ template struct MultiplyFunctor if constexpr (tu_ns::is_complex::value && tu_ns::is_complex::value) { +#ifdef USE_SYCL_FOR_COMPLEX_TYPES using realT1 = typename argT1::value_type; using realT2 = typename argT2::value_type; return exprm_ns::complex(in1) * exprm_ns::complex(in2); +#else + return in1 * in2; +#endif } else { return in1 * in2; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp index 95e8442903..9068e67f10 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp @@ -27,10 +27,10 @@ #include #include #include -#include #include #include +#include "sycl_complex.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" #include "utils/type_utils.hpp" @@ -51,7 +51,6 @@ namespace pow namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; -namespace exprm_ns = sycl::ext::oneapi::experimental; template struct PowFunctor { @@ -88,11 +87,15 @@ template struct PowFunctor else if constexpr (tu_ns::is_complex::value && tu_ns::is_complex::value) { +#ifdef USE_SYCL_FOR_COMPLEX_TYPES using realT1 = typename argT1::value_type; using realT2 = typename argT2::value_type; return exprm_ns::pow(exprm_ns::complex(in1), exprm_ns::complex(in2)); +#else + return std::pow(in1, in2); +#endif } else { return std::pow(in1, in2); @@ -365,11 +368,15 @@ template struct PowInplaceFunctor else if constexpr (tu_ns::is_complex::value && tu_ns::is_complex::value) { +#ifdef USE_SYCL_FOR_COMPLEX_TYPES using r_resT = typename resT::value_type; using r_argT = typename argT::value_type; res = exprm_ns::pow(exprm_ns::complex(res), exprm_ns::complex(in)); +#else + res = std::pow(res, in); +#endif } else { res = std::pow(res, in); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp index 162db394de..521e935c16 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp @@ -27,11 +27,11 @@ #include #include #include -#include #include #include #include "kernels/elementwise_functions/common.hpp" +#include "sycl_complex.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" @@ -49,7 +49,6 @@ namespace sign namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; -namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; using dpctl::tensor::type_utils::vec_cast; @@ -81,8 +80,12 @@ template struct SignFunctor return resT(0); } else { +#ifdef USE_SYCL_FOR_COMPLEX_TYPES auto z = exprm_ns::complex(in); return (z / exprm_ns::abs(z)); +#else + return in / std::abs(in); +#endif } } else { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp index e1e9e79c57..97768fc8e9 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp @@ -26,11 +26,11 @@ #include #include #include -#include #include #include #include "kernels/elementwise_functions/common.hpp" +#include "sycl_complex.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" @@ -48,7 +48,6 @@ namespace sin namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; -namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -81,8 +80,12 @@ template struct SinFunctor * real and imaginary parts of input are finite. */ if (in_re_finite && in_im_finite) { +#ifdef USE_SYCL_FOR_COMPLEX_TYPES return exprm_ns::sin( exprm_ns::complex(in)); // std::sin(in); +#else + return std::sin(in); +#endif } /* diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp index b11c7402d0..fe7dae533b 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp @@ -26,11 +26,11 @@ #include #include #include -#include #include #include #include "kernels/elementwise_functions/common.hpp" +#include "sycl_complex.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" @@ -48,7 +48,6 @@ namespace sinh namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; -namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -81,7 +80,11 @@ template struct SinhFunctor * real and imaginary parts of input are finite. */ if (xfinite && yfinite) { +#ifdef USE_SYCL_FOR_COMPLEX_TYPES return exprm_ns::sinh(exprm_ns::complex(in)); +#else + return std::sinh(in); +#endif } /* * sinh(+-0 +- I Inf) = sign(d(+-0, dNaN))0 + I dNaN. diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp index b638e4a55f..2d423502a7 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp @@ -29,11 +29,11 @@ #include #include #include -#include #include #include #include "kernels/elementwise_functions/common.hpp" +#include "sycl_complex.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" @@ -51,7 +51,6 @@ namespace sqrt namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; -namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -71,15 +70,25 @@ template struct SqrtFunctor resT operator()(const argT &in) const { if constexpr (is_complex::value) { - // #ifdef _WINDOWS - // return csqrt(in); - // #else - // return std::sqrt(in); - // #endif using realT = typename argT::value_type; - - // return csqrt(in); +#ifdef USE_SYCL_FOR_COMPLEX_TYPES return exprm_ns::sqrt(exprm_ns::complex(in)); +#else +#ifdef _WINDOWS + // Work around a problem on Windows, where std::sqrt for + // single precision input uses double type, precluding + // offloading to devices that do not support double precision + // i.e. Iris Xe + if constexpr (std::is_same_v) { + return std::sqrt(in); + } + else { + return csqrt(in); + } +#else + return std::sqrt(in); +#endif +#endif } else { return std::sqrt(in); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp index 2c37ce87d9..c888c4eefa 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp @@ -27,11 +27,11 @@ #include #include #include -#include #include #include #include "kernels/elementwise_functions/common.hpp" +#include "sycl_complex.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" @@ -49,7 +49,6 @@ namespace square namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; -namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; using dpctl::tensor::type_utils::vec_cast; @@ -71,11 +70,15 @@ template struct SquareFunctor resT operator()(const argT &in) const { if constexpr (is_complex::value) { +#ifdef USE_SYCL_FOR_COMPLEX_TYPES using realT = typename argT::value_type; auto z = exprm_ns::complex(in); return z * z; +#else + return in * in; +#endif } else { return in * in; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sycl_complex.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sycl_complex.hpp new file mode 100644 index 0000000000..57dc97bb1e --- /dev/null +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sycl_complex.hpp @@ -0,0 +1,13 @@ +#pragma once +#ifdef USE_SYCL_FOR_COMPLEX_TYPES +#define SYCL_EXT_ONEAPI_COMPLEX +#if __has_include() +#include +#else +#include +#endif +#endif + +#ifdef USE_SYCL_FOR_COMPLEX_TYPES +namespace exprm_ns = sycl::ext::oneapi::experimental; +#endif diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp index 1f97b59054..2c648edd7b 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp @@ -27,11 +27,11 @@ #include #include #include -#include #include #include #include "kernels/elementwise_functions/common.hpp" +#include "sycl_complex.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" @@ -49,7 +49,6 @@ namespace tan namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; -namespace cmplx_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -120,7 +119,11 @@ template struct TanFunctor return resT{q_nan, q_nan}; } /* ordinary cases */ - return cmplx_ns::tan(cmplx_ns::complex(in)); // std::tan(in); +#ifdef USE_SYCL_FOR_COMPLEX_TYPES + return exprm_ns::tan(exprm_ns::complex(in)); // std::tan(in); +#else + return std::tan(in); +#endif } else { static_assert(std::is_floating_point_v || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp index 453ce17b54..84ba7a9a3f 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp @@ -28,11 +28,11 @@ #include #include #include -#include #include #include #include "kernels/elementwise_functions/common.hpp" +#include "sycl_complex.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" @@ -50,7 +50,6 @@ namespace tanh namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; -namespace cmplx_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -114,8 +113,12 @@ template struct TanhFunctor return resT{q_nan, q_nan}; } /* ordinary cases */ - return cmplx_ns::tanh( - cmplx_ns::complex(in)); // std::tanh(in); +#ifdef USE_SYCL_FOR_COMPLEX_TYPES + return exprm_ns::tanh( + exprm_ns::complex(in)); // std::tanh(in); +#else + return std::tanh(in); +#endif } else { static_assert(std::is_floating_point_v || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp index 6620d2e3c1..c72634d106 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp @@ -26,10 +26,10 @@ #pragma once #include #include -#include #include #include +#include "sycl_complex.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" #include "utils/type_utils.hpp" @@ -50,7 +50,6 @@ namespace true_divide namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; -namespace exprm_ns = sycl::ext::oneapi::experimental; template struct TrueDivideFunctor @@ -66,25 +65,37 @@ struct TrueDivideFunctor if constexpr (tu_ns::is_complex::value && tu_ns::is_complex::value) { +#ifdef USE_SYCL_FOR_COMPLEX_TYPES using realT1 = typename argT1::value_type; using realT2 = typename argT2::value_type; return exprm_ns::complex(in1) / exprm_ns::complex(in2); +#else + return in1 / in2; +#endif } else if constexpr (tu_ns::is_complex::value && !tu_ns::is_complex::value) { +#ifdef USE_SYCL_FOR_COMPLEX_TYPES using realT1 = typename argT1::value_type; return exprm_ns::complex(in1) / in2; +#else + return in1 / in2; +#endif } else if constexpr (!tu_ns::is_complex::value && tu_ns::is_complex::value) { +#ifdef USE_SYCL_FOR_COMPLEX_TYPES using realT2 = typename argT2::value_type; return in1 / exprm_ns::complex(in2); +#else + return in1 / in2; +#endif } else { return in1 / in2; @@ -409,18 +420,28 @@ template struct TrueDivideInplaceFunctor void operator()(resT &res, const argT &in) { if constexpr (tu_ns::is_complex::value) { - using res_rT = typename resT::value_type; if constexpr (tu_ns::is_complex::value) { +#ifdef USE_SYCL_FOR_COMPLEX_TYPES + using res_rT = typename resT::value_type; using arg_rT = typename argT::value_type; auto res1 = exprm_ns::complex(res); res1 /= exprm_ns::complex(in); res = res1; +#else + res /= in; +#endif } else { +#ifdef USE_SYCL_FOR_COMPLEX_TYPES + using res_rT = typename resT::value_type; + auto res1 = exprm_ns::complex(res); res1 /= in; res = res1; +#else + res /= in; +#endif } } else { From 8405af7a4945b2f22be19dc5ce3abc200e65a155 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Mon, 20 Nov 2023 07:11:21 -0600 Subject: [PATCH 3/4] Streamlined cmake script, set USE_SYCL_FOR_COMPLEX_TYPES on all platforms --- dpctl/tensor/CMakeLists.txt | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/dpctl/tensor/CMakeLists.txt b/dpctl/tensor/CMakeLists.txt index d6346bfd50..3d5cd3d544 100644 --- a/dpctl/tensor/CMakeLists.txt +++ b/dpctl/tensor/CMakeLists.txt @@ -187,11 +187,8 @@ foreach(_src_fn ${_no_fast_math_sources}) PROPERTIES COMPILE_OPTIONS "${_combined_options_prop}" ) endforeach() -if (UNIX) - set(_compiler_definitions "USE_SYCL_FOR_COMPLEX_TYPES") -else() - set(_compiler_definitions "USE_SYCL_FOR_COMPLEX_TYPES") -endif() + +set(_compiler_definitions "USE_SYCL_FOR_COMPLEX_TYPES") foreach(_src_fn ${_elementwise_sources}) get_source_file_property(_cmpl_options_defs ${_src_fn} COMPILE_DEFINITIONS) From 0efe28bb0584a670269587af2956df72482369e6 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Mon, 20 Nov 2023 07:17:15 -0600 Subject: [PATCH 4/4] Removed superflous namespace creation line --- .../libtensor/include/kernels/elementwise_functions/isfinite.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp index 1554f905b7..cef7c96ed2 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp @@ -46,7 +46,6 @@ namespace isfinite namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; -namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; using dpctl::tensor::type_utils::vec_cast;