Skip to content

Commit

Permalink
Update dpnp.power using dpctl and OneMKL implementations (#1476)
Browse files Browse the repository at this point in the history
* Reuse dpctl.tensor.pow for dpnp.power

* Add pow call from OneMKL by pybind11 extension

* Update all tests for dpnp.power

* Update examples for dpnp.power

* Update dpnp_power and use OneMKL only on Linux for it

* Restore deleted funcs in test_arithmetic

* Remove dpnp_init_val

* Skip test_copy

---------

Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com>
  • Loading branch information
vlad-perevezentsev and antonwolfy authored Aug 29, 2023
1 parent b694d87 commit fa3cd55
Show file tree
Hide file tree
Showing 17 changed files with 388 additions and 732 deletions.
81 changes: 81 additions & 0 deletions dpnp/backend/extensions/vm/pow.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
//*****************************************************************************
// Copyright (c) 2023, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#pragma once

#include <CL/sycl.hpp>

#include "common.hpp"
#include "types_matrix.hpp"

namespace dpnp
{
namespace backend
{
namespace ext
{
namespace vm
{
/**
 * @brief Element-wise power over contiguous data, delegated to `mkl_vm::pow`.
 *
 * The input and output buffers are passed type-erased (as `char` pointers) and
 * are reinterpreted here as arrays of `T` holding `n` elements each.
 *
 * @tparam T Element type shared by both inputs and the output.
 *
 * @param exec_q  Queue the computation is submitted to.
 * @param n       Number of elements to be calculated.
 * @param in_a    First input vector (bases), `n` elements of `T`.
 * @param in_b    Second input vector (exponents), `n` elements of `T`.
 * @param out_y   Output vector, `n` elements of `T`.
 * @param depends Events passed through to the `mkl_vm::pow` call.
 *
 * @return The event returned by `mkl_vm::pow`.
 */
template <typename T>
sycl::event pow_contig_impl(sycl::queue exec_q,
                            const std::int64_t n,
                            const char *in_a,
                            const char *in_b,
                            char *out_y,
                            const std::vector<sycl::event> &depends)
{
    // Run the device/type validation helper before any buffer is touched.
    type_utils::validate_type_for_device<T>(exec_q);

    // View the raw byte buffers as typed vectors.
    const auto *base = reinterpret_cast<const T *>(in_a);
    const auto *exponent = reinterpret_cast<const T *>(in_b);
    auto *result = reinterpret_cast<T *>(out_y);

    return mkl_vm::pow(exec_q, n, base, exponent, result, depends);
}

template <typename fnT, typename T>
struct PowContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<
typename types::PowOutputType<T>::value_type, void>)
{
return nullptr;
}
else {
return pow_contig_impl<T>;
}
}
};
} // namespace vm
} // namespace ext
} // namespace backend
} // namespace dpnp
25 changes: 25 additions & 0 deletions dpnp/backend/extensions/vm/types_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,31 @@ struct MulOutputType
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
};

/**
* @brief A factory to define pairs of supported types for which
* MKL VM library provides support in oneapi::mkl::vm::pow<T> function.
*
* @tparam T Type of input vectors `a` and `b` and of result vector `y`.
*/
template <typename T>
struct PowOutputType
{
using value_type = typename std::disjunction<
dpctl_td_ns::BinaryTypeMapResultEntry<T,
std::complex<double>,
T,
std::complex<double>,
std::complex<double>>,
dpctl_td_ns::BinaryTypeMapResultEntry<T,
std::complex<float>,
T,
std::complex<float>,
std::complex<float>>,
dpctl_td_ns::BinaryTypeMapResultEntry<T, double, T, double, double>,
dpctl_td_ns::BinaryTypeMapResultEntry<T, float, T, float, float>,
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
};

/**
* @brief A factory to define pairs of supported types for which
* MKL VM library provides support in oneapi::mkl::vm::rint<T> function.
Expand Down
32 changes: 32 additions & 0 deletions dpnp/backend/extensions/vm/vm_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "floor.hpp"
#include "ln.hpp"
#include "mul.hpp"
#include "pow.hpp"
#include "round.hpp"
#include "sin.hpp"
#include "sqr.hpp"
Expand All @@ -61,6 +62,7 @@ static unary_impl_fn_ptr_t floor_dispatch_vector[dpctl_td_ns::num_types];
static unary_impl_fn_ptr_t conj_dispatch_vector[dpctl_td_ns::num_types];
static unary_impl_fn_ptr_t ln_dispatch_vector[dpctl_td_ns::num_types];
static binary_impl_fn_ptr_t mul_dispatch_vector[dpctl_td_ns::num_types];
static binary_impl_fn_ptr_t pow_dispatch_vector[dpctl_td_ns::num_types];
static unary_impl_fn_ptr_t round_dispatch_vector[dpctl_td_ns::num_types];
static unary_impl_fn_ptr_t sin_dispatch_vector[dpctl_td_ns::num_types];
static unary_impl_fn_ptr_t sqr_dispatch_vector[dpctl_td_ns::num_types];
Expand Down Expand Up @@ -303,6 +305,36 @@ PYBIND11_MODULE(_vm_impl, m)
py::arg("dst"));
}

// BinaryUfunc: ==== Pow(x1, x2) ====
// Registers two Python-visible functions for `pow`:
//   * `_pow`             - runs the computation through `binary_ufunc`;
//   * `_mkl_pow_to_call` - asks `need_to_call_binary_ufunc` whether the
//                          given arguments can be handled.
// Both consult `pow_dispatch_vector`, which is set up first.
{
    // Initialize the dispatch vector using PowContigFactory.
    // NOTE(review): presumably one entry per dpctl type id, with nullptr for
    // types the factory does not support - confirm in
    // init_ufunc_dispatch_vector.
    vm_ext::init_ufunc_dispatch_vector<binary_impl_fn_ptr_t,
                                       vm_ext::PowContigFactory>(
        pow_dispatch_vector);

    // Callable bound to `_pow`: forwards everything to binary_ufunc together
    // with the pow dispatch vector.
    auto pow_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
                         arrayT dst, const event_vecT &depends = {}) {
        return vm_ext::binary_ufunc(exec_q, src1, src2, dst, depends,
                                    pow_dispatch_vector);
    };
    // `depends` is optional on the Python side and defaults to an empty list.
    m.def("_pow", pow_pyapi,
          "Call `pow` function from OneMKL VM library to performs element "
          "by element exponentiation of vector `src1` raised to the power "
          "of vector `src2` to resulting vector `dst`",
          py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"),
          py::arg("dst"), py::arg("depends") = py::list());

    // Callable bound to `_mkl_pow_to_call`: same arguments as `_pow` minus
    // `depends`; only checks applicability.
    auto pow_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src1,
                                      arrayT src2, arrayT dst) {
        return vm_ext::need_to_call_binary_ufunc(exec_q, src1, src2, dst,
                                                 pow_dispatch_vector);
    };
    m.def("_mkl_pow_to_call", pow_need_to_call_pyapi,
          "Check input arguments to answer if `pow` function from "
          "OneMKL VM library can be used",
          py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"),
          py::arg("dst"));
}

// UnaryUfunc: ==== Round(x) ====
{
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
Expand Down
2 changes: 0 additions & 2 deletions dpnp/backend/include/dpnp_iface_fptr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,6 @@ enum class DPNPFuncName : size_t
parameters */
DPNP_FN_PLACE, /**< Used in numpy.place() impl */
DPNP_FN_POWER, /**< Used in numpy.power() impl */
DPNP_FN_POWER_EXT, /**< Used in numpy.power() impl, requires extra
parameters */
DPNP_FN_PROD, /**< Used in numpy.prod() impl */
DPNP_FN_PROD_EXT, /**< Used in numpy.prod() impl, requires extra parameters
*/
Expand Down
7 changes: 0 additions & 7 deletions dpnp/backend/kernels/dpnp_krnl_elemwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1565,13 +1565,6 @@ static void func_map_elemwise_2arg_3type_core(func_map_t &fmap)
func_type_map_t::find_type<FT1>,
func_type_map_t::find_type<FTs>>}),
...);
((fmap[DPNPFuncName::DPNP_FN_POWER_EXT][FT1][FTs] =
{populate_func_types<FT1, FTs>(),
(void *)dpnp_power_c_ext<
func_type_map_t::find_type<populate_func_types<FT1, FTs>()>,
func_type_map_t::find_type<FT1>,
func_type_map_t::find_type<FTs>>}),
...);
((fmap[DPNPFuncName::DPNP_FN_SUBTRACT_EXT][FT1][FTs] =
{populate_func_types<FT1, FTs>(),
(void *)dpnp_subtract_c_ext<
Expand Down
7 changes: 0 additions & 7 deletions dpnp/dpnp_algo/dpnp_algo.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
DPNP_FN_HYPOT_EXT
DPNP_FN_IDENTITY
DPNP_FN_IDENTITY_EXT
DPNP_FN_INITVAL
DPNP_FN_INITVAL_EXT
DPNP_FN_INV
DPNP_FN_INV_EXT
DPNP_FN_KRON
Expand Down Expand Up @@ -164,8 +162,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
DPNP_FN_PARTITION
DPNP_FN_PARTITION_EXT
DPNP_FN_PLACE
DPNP_FN_POWER
DPNP_FN_POWER_EXT
DPNP_FN_PROD
DPNP_FN_PROD_EXT
DPNP_FN_PTP
Expand Down Expand Up @@ -407,7 +403,6 @@ cpdef dpnp_descriptor dpnp_matmul(dpnp_descriptor in_array1, dpnp_descriptor in_
"""
Array creation routines
"""
cpdef dpnp_descriptor dpnp_init_val(shape, dtype, value)
cpdef dpnp_descriptor dpnp_copy(dpnp_descriptor x1)

"""
Expand All @@ -421,8 +416,6 @@ cpdef dpnp_descriptor dpnp_maximum(dpnp_descriptor x1_obj, dpnp_descriptor x2_ob
dpnp_descriptor out=*, object where=*)
cpdef dpnp_descriptor dpnp_minimum(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
dpnp_descriptor out=*, object where=*)
cpdef dpnp_descriptor dpnp_power(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
dpnp_descriptor out=*, object where=*)

"""
Array manipulation routines
Expand Down
37 changes: 0 additions & 37 deletions dpnp/dpnp_algo/dpnp_algo.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ import numpy
__all__ = [
"dpnp_astype",
"dpnp_flatten",
"dpnp_init_val",
"dpnp_queue_initialize",
]

Expand Down Expand Up @@ -85,9 +84,6 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_dpnp_flatten_t)(c_dpctl.DPCTLSyclQueueR
const shape_elem_type * , const shape_elem_type * ,
const long * ,
const c_dpctl.DPCTLEventVectorRef)
ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_dpnp_initval_t)(c_dpctl.DPCTLSyclQueueRef,
void *, void * , size_t,
const c_dpctl.DPCTLEventVectorRef)


cpdef utils.dpnp_descriptor dpnp_astype(utils.dpnp_descriptor x1, dtype):
Expand Down Expand Up @@ -168,39 +164,6 @@ cpdef utils.dpnp_descriptor dpnp_flatten(utils.dpnp_descriptor x1):
return result


cpdef utils.dpnp_descriptor dpnp_init_val(shape, dtype, value):
    """
    Create an array of the given `shape` and `dtype` and run the
    DPNP_FN_INITVAL_EXT backend kernel on it with `value`.

    same as dpnp_full(). TODO remove code duplication

    Parameters:
        shape: shape passed to the output descriptor factory.
        dtype: data type of the result; also selects the backend kernel.
        value: scalar stored into a one-element helper array that is handed
            to the kernel.

    Returns:
        utils.dpnp_descriptor: descriptor of the newly created array.
    """
    # Map the dtype to the backend type id used for kernel lookup.
    cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(dtype)

    # The same type id is used for both type parameters of the lookup.
    cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_INITVAL_EXT, param1_type, param1_type)

    cdef utils.dpnp_descriptor result = utils_py.create_output_descriptor_py(shape, dtype, None)

    result_obj = result.get_array()

    # TODO: find better way to pass single value with type conversion
    # One-element array created with the result's device, USM type and queue;
    # assigning `value` into it performs the conversion to `dtype`.
    cdef utils.dpnp_descriptor val_arr = utils_py.create_output_descriptor_py((1, ),
                                                                              dtype,
                                                                              None,
                                                                              device=result_obj.sycl_device,
                                                                              usm_type=result_obj.usm_type,
                                                                              sycl_queue=result_obj.sycl_queue)
    val_arr.get_pyobj()[0] = value

    cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_obj.sycl_queue
    cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()

    # Call the kernel with no dependency events (NULL).
    cdef fptr_dpnp_initval_t func = <fptr_dpnp_initval_t > kernel_data.ptr
    cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref, result.get_data(), val_arr.get_data(), result.size, NULL)

    # Wait for the kernel without holding the GIL, then release the event.
    with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
    c_dpctl.DPCTLEvent_Delete(event_ref)

    return result


cpdef dpnp_queue_initialize():
"""
Initialize SYCL queue which will be used for any library operations.
Expand Down
9 changes: 0 additions & 9 deletions dpnp/dpnp_algo/dpnp_algo_mathematical.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ __all__ += [
"dpnp_nancumsum",
"dpnp_nanprod",
"dpnp_nansum",
"dpnp_power",
"dpnp_prod",
"dpnp_sum",
"dpnp_trapz",
Expand Down Expand Up @@ -417,14 +416,6 @@ cpdef utils.dpnp_descriptor dpnp_nansum(utils.dpnp_descriptor x1):
return dpnp_sum(result)


cpdef utils.dpnp_descriptor dpnp_power(utils.dpnp_descriptor x1_obj,
                                       utils.dpnp_descriptor x2_obj,
                                       object dtype=None,
                                       utils.dpnp_descriptor out=None,
                                       object where=True):
    """
    Thin wrapper that forwards all arguments to `call_fptr_2in_1out_strides`
    with the DPNP_FN_POWER_EXT function id and the name "power".
    """
    return call_fptr_2in_1out_strides(DPNP_FN_POWER_EXT, x1_obj, x2_obj, dtype, out, where, func_name="power")


cpdef utils.dpnp_descriptor dpnp_prod(utils.dpnp_descriptor x1,
object axis=None,
object dtype=None,
Expand Down
66 changes: 66 additions & 0 deletions dpnp/dpnp_algo/dpnp_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
# *****************************************************************************


from sys import platform

import dpctl.tensor._tensor_impl as ti
from dpctl.tensor._elementwise_common import (
BinaryElementwiseFunc,
Expand Down Expand Up @@ -68,6 +70,7 @@
"dpnp_multiply",
"dpnp_negative",
"dpnp_not_equal",
"dpnp_power",
"dpnp_proj",
"dpnp_remainder",
"dpnp_right_shift",
Expand Down Expand Up @@ -1460,6 +1463,69 @@ def dpnp_not_equal(x1, x2, out=None, order="K"):
return dpnp_array._create_from_usm_ndarray(res_usm)


_power_docstring_ = """
power(x1, x2, out=None, order="K")
Calculates `x1_i` raised to `x2_i` for each element `x1_i` of the input array
`x1` with the respective element `x2_i` of the input array `x2`.
Args:
x1 (dpnp.ndarray):
First input array, expected to have numeric data type.
x2 (dpnp.ndarray):
Second input array, also expected to have numeric data type.
out ({None, dpnp.ndarray}, optional):
Output array to populate. Array must have the correct
shape and the expected data type.
order ("C","F","A","K", None, optional):
Output array, if parameter `out` is `None`.
Default: "K".
Returns:
dpnp.ndarray:
An array containing the result of element-wise of raising each element
to a specified power.
The data type of the returned array is determined by the Type Promotion Rules.
"""


def _call_pow(src1, src2, dst, sycl_queue, depends=None):
    """
    Callback registered in the BinaryElementwiseFunc class of dpctl.tensor.

    Uses the OneMKL VM `pow` from the pybind11 extension when the platform is
    not Windows and the extension reports that it can handle the arguments;
    otherwise falls back to the dpctl.tensor implementation.
    """

    deps = [] if depends is None else depends

    # TODO: remove this check when OneMKL is fixed on Windows
    if platform.startswith("win"):
        return ti._pow(src1, src2, dst, sycl_queue, deps)

    if vmi._mkl_pow_to_call(sycl_queue, src1, src2, dst):
        # pybind11 extension wrapping pow() from OneMKL VM
        return vmi._pow(sycl_queue, src1, src2, dst, deps)

    return ti._pow(src1, src2, dst, sycl_queue, deps)


# Element-wise `pow` function object: result types come from
# `ti._pow_result_type`, the computation is carried out by `_call_pow`, and
# `_power_docstring_` supplies the documentation text.
pow_func = BinaryElementwiseFunc(
    "pow", ti._pow_result_type, _call_pow, _power_docstring_
)


def dpnp_power(x1, x2, out=None, order="K"):
    """
    Compute `x1` raised to the power `x2` element-wise via `pow_func`.

    The pow() function from the pybind11 extension of OneMKL VM is used when
    possible; otherwise the dpctl.tensor implementation of pow() does the work.
    The result is returned wrapped as a dpnp array.
    """

    # pow_func accepts only usm_ndarray or scalar operands, so unwrap first
    lhs = dpnp.get_usm_ndarray_or_scalar(x1)
    rhs = dpnp.get_usm_ndarray_or_scalar(x2)
    dst = dpnp.get_usm_ndarray(out) if out is not None else None

    return dpnp_array._create_from_usm_ndarray(
        pow_func(lhs, rhs, out=dst, order=order)
    )


_proj_docstring = """
proj(x, out=None, order="K")
Expand Down
Loading

0 comments on commit fa3cd55

Please sign in to comment.