Skip to content

Commit

Permalink
implement dpnp.argmin and dpnp.argmax using dpctl.tensor (#1610)
Browse files Browse the repository at this point in the history
* rework implementation of diag, diagflat, vander, and ptp

* address comments - first round

cherry-pick

* address comments - second round

* add tests for negative use cases to improve covergae

* fixed missing merge conflicts

* fix pre-commit

* implement dpnp.argmin and dpnp.argmax using dpctl.tensor

* address comments

* add tests for negative use cases to improve coverage

* remove unneccessary parts with updates in dpctl #1465

* add paramater section in doc

* update ndarray.argmin and ndarray.argmax function signature

* use a utility func for returning output

* add tests for ndarray implementation

* Place new function acc to lexicographical order

---------

Co-authored-by: Anton Volkov <antonwolfy@gmail.com>
Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com>
  • Loading branch information
3 people committed Nov 23, 2023
1 parent 525116a commit 7d0815f
Show file tree
Hide file tree
Showing 18 changed files with 346 additions and 392 deletions.
4 changes: 0 additions & 4 deletions dpnp/backend/include/dpnp_iface_fptr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,7 @@ enum class DPNPFuncName : size_t
DPNP_FN_ARCTAN2, /**< Used in numpy.arctan2() impl */
DPNP_FN_ARCTANH, /**< Used in numpy.arctanh() impl */
DPNP_FN_ARGMAX, /**< Used in numpy.argmax() impl */
DPNP_FN_ARGMAX_EXT, /**< Used in numpy.argmax() impl, requires extra
parameters */
DPNP_FN_ARGMIN, /**< Used in numpy.argmin() impl */
DPNP_FN_ARGMIN_EXT, /**< Used in numpy.argmin() impl, requires extra
parameters */
DPNP_FN_ARGSORT, /**< Used in numpy.argsort() impl */
DPNP_FN_ARGSORT_EXT, /**< Used in numpy.argsort() impl, requires extra
parameters */
Expand Down
50 changes: 0 additions & 50 deletions dpnp/backend/kernels/dpnp_krnl_searching.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,6 @@ void (*dpnp_argmax_default_c)(void *,
void *,
size_t) = dpnp_argmax_c<_DataType, _idx_DataType>;

template <typename _DataType, typename _idx_DataType>
DPCTLSyclEventRef (*dpnp_argmax_ext_c)(DPCTLSyclQueueRef,
void *,
void *,
size_t,
const DPCTLEventVectorRef) =
dpnp_argmax_c<_DataType, _idx_DataType>;

template <typename _DataType, typename _idx_DataType>
class dpnp_argmin_c_kernel;

Expand Down Expand Up @@ -133,14 +125,6 @@ void (*dpnp_argmin_default_c)(void *,
void *,
size_t) = dpnp_argmin_c<_DataType, _idx_DataType>;

template <typename _DataType, typename _idx_DataType>
DPCTLSyclEventRef (*dpnp_argmin_ext_c)(DPCTLSyclQueueRef,
void *,
void *,
size_t,
const DPCTLEventVectorRef) =
dpnp_argmin_c<_DataType, _idx_DataType>;

void func_map_init_searching(func_map_t &fmap)
{
fmap[DPNPFuncName::DPNP_FN_ARGMAX][eft_INT][eft_INT] = {
Expand All @@ -160,23 +144,6 @@ void func_map_init_searching(func_map_t &fmap)
fmap[DPNPFuncName::DPNP_FN_ARGMAX][eft_DBL][eft_LNG] = {
eft_LNG, (void *)dpnp_argmax_default_c<double, int64_t>};

fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_INT][eft_INT] = {
eft_INT, (void *)dpnp_argmax_ext_c<int32_t, int32_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_INT][eft_LNG] = {
eft_LNG, (void *)dpnp_argmax_ext_c<int32_t, int64_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_LNG][eft_INT] = {
eft_INT, (void *)dpnp_argmax_ext_c<int64_t, int32_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_LNG][eft_LNG] = {
eft_LNG, (void *)dpnp_argmax_ext_c<int64_t, int64_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_FLT][eft_INT] = {
eft_INT, (void *)dpnp_argmax_ext_c<float, int32_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_FLT][eft_LNG] = {
eft_LNG, (void *)dpnp_argmax_ext_c<float, int64_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_DBL][eft_INT] = {
eft_INT, (void *)dpnp_argmax_ext_c<double, int32_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_DBL][eft_LNG] = {
eft_LNG, (void *)dpnp_argmax_ext_c<double, int64_t>};

fmap[DPNPFuncName::DPNP_FN_ARGMIN][eft_INT][eft_INT] = {
eft_INT, (void *)dpnp_argmin_default_c<int32_t, int32_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMIN][eft_INT][eft_LNG] = {
Expand All @@ -194,22 +161,5 @@ void func_map_init_searching(func_map_t &fmap)
fmap[DPNPFuncName::DPNP_FN_ARGMIN][eft_DBL][eft_LNG] = {
eft_LNG, (void *)dpnp_argmin_default_c<double, int64_t>};

fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_INT][eft_INT] = {
eft_INT, (void *)dpnp_argmin_ext_c<int32_t, int32_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_INT][eft_LNG] = {
eft_LNG, (void *)dpnp_argmin_ext_c<int32_t, int64_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_LNG][eft_INT] = {
eft_INT, (void *)dpnp_argmin_ext_c<int64_t, int32_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_LNG][eft_LNG] = {
eft_LNG, (void *)dpnp_argmin_ext_c<int64_t, int64_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_FLT][eft_INT] = {
eft_INT, (void *)dpnp_argmin_ext_c<float, int32_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_FLT][eft_LNG] = {
eft_LNG, (void *)dpnp_argmin_ext_c<float, int64_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_DBL][eft_INT] = {
eft_INT, (void *)dpnp_argmin_ext_c<double, int32_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_DBL][eft_LNG] = {
eft_LNG, (void *)dpnp_argmin_ext_c<double, int64_t>};

return;
}
1 change: 0 additions & 1 deletion dpnp/dpnp_algo/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ set(dpnp_algo_pyx_deps
${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_sorting.pxi
${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_arraycreation.pxi
${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_mathematical.pxi
${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_searching.pxi
${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_indexing.pxi
${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_logic.pxi
${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_special.pxi
Expand Down
10 changes: 0 additions & 10 deletions dpnp/dpnp_algo/dpnp_algo.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
DPNP_FN_ALLCLOSE
DPNP_FN_ALLCLOSE_EXT
DPNP_FN_ARANGE
DPNP_FN_ARGMAX
DPNP_FN_ARGMAX_EXT
DPNP_FN_ARGMIN
DPNP_FN_ARGMIN_EXT
DPNP_FN_ARGSORT
DPNP_FN_ARGSORT_EXT
DPNP_FN_CBRT
Expand Down Expand Up @@ -355,12 +351,6 @@ Sorting functions
cpdef dpnp_descriptor dpnp_argsort(dpnp_descriptor array1)
cpdef dpnp_descriptor dpnp_sort(dpnp_descriptor array1)

"""
Searching functions
"""
cpdef dpnp_descriptor dpnp_argmax(dpnp_descriptor array1)
cpdef dpnp_descriptor dpnp_argmin(dpnp_descriptor array1)

"""
Trigonometric functions
"""
Expand Down
1 change: 0 additions & 1 deletion dpnp/dpnp_algo/dpnp_algo.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ include "dpnp_algo_indexing.pxi"
include "dpnp_algo_linearalgebra.pxi"
include "dpnp_algo_logic.pxi"
include "dpnp_algo_mathematical.pxi"
include "dpnp_algo_searching.pxi"
include "dpnp_algo_sorting.pxi"
include "dpnp_algo_special.pxi"
include "dpnp_algo_statistics.pxi"
Expand Down
119 changes: 0 additions & 119 deletions dpnp/dpnp_algo/dpnp_algo_searching.pxi

This file was deleted.

47 changes: 6 additions & 41 deletions dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,58 +486,23 @@ def any(self, axis=None, out=None, keepdims=False, *, where=True):
self, axis=axis, out=out, keepdims=keepdims, where=where
)

def argmax(self, axis=None, out=None):
def argmax(self, axis=None, out=None, *, keepdims=False):
"""
Returns array of indices of the maximum values along the given axis.
Parameters
----------
axis : {None, integer}
If None, the index is into the flattened array, otherwise along
the specified axis
out : {None, array}, optional
Array into which the result can be placed. Its type is preserved
and it must be of the right shape to hold the output.
Returns
-------
index_array : {integer_array}
Examples
--------
>>> a = np.arange(6).reshape(2,3)
>>> a.argmax()
5
>>> a.argmax(0)
array([1, 1, 1])
>>> a.argmax(1)
array([2, 2])
Refer to :obj:`dpnp.argmax` for full documentation.
"""
return dpnp.argmax(self, axis, out)
return dpnp.argmax(self, axis, out, keepdims=keepdims)

def argmin(self, axis=None, out=None):
def argmin(self, axis=None, out=None, *, keepdims=False):
"""
Return array of indices to the minimum values along the given axis.
Parameters
----------
axis : {None, integer}
If None, the index is into the flattened array, otherwise along
the specified axis
out : {None, array}, optional
Array into which the result can be placed. Its type is preserved
and it must be of the right shape to hold the output.
Returns
-------
ndarray or scalar
If multi-dimension input, returns a new ndarray of indices to the
minimum values along the given axis. Otherwise, returns a scalar
of index to the minimum values along the given axis.
Refer to :obj:`dpnp.argmin` for full documentation.
"""
return dpnp.argmin(self, axis, out)
return dpnp.argmin(self, axis, out, keepdims=keepdims)

# 'argpartition',

Expand Down
46 changes: 46 additions & 0 deletions dpnp/dpnp_iface.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
"get_dpnp_descriptor",
"get_include",
"get_normalized_queue_device",
"get_result_array",
"get_usm_ndarray",
"get_usm_ndarray_or_scalar",
"is_supported_array_or_scalar",
Expand Down Expand Up @@ -418,6 +419,51 @@ def get_normalized_queue_device(obj=None, device=None, sycl_queue=None):
)


def get_result_array(a, out=None):
"""
If `out` is provided, value of `a` array will be copied into the
`out` array according to ``safe`` casting rule.
Otherwise, the input array `a` is returned.
Parameters
----------
a : {dpnp_array}
Input array.
out : {dpnp_array, usm_ndarray}
If provided, value of `a` array will be copied into it
according to ``safe`` casting rule.
It should be of the appropriate shape.
Returns
-------
out : {dpnp_array}
Return `out` if provided, otherwise return `a`.
"""

if out is None:
return a
else:
if out.shape != a.shape:
raise ValueError(
f"Output array of shape {a.shape} is needed, got {out.shape}."
)
elif not isinstance(out, dpnp_array):
if isinstance(out, dpt.usm_ndarray):
out = dpnp_array._create_from_usm_ndarray(out)
else:
raise TypeError(
"Output array must be any of supported type, but got {}".format(
type(out)
)
)

dpnp.copyto(out, a, casting="safe")

return out


def get_usm_ndarray(a):
"""
Return :class:`dpctl.tensor.usm_ndarray` from input array `a`.
Expand Down
Loading

0 comments on commit 7d0815f

Please sign in to comment.