Skip to content

Commit

Permalink
implement dpnp.argmin and dpnp.argmax using dpctl.tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
vtavana committed Oct 27, 2023
1 parent 01b3948 commit 63f630c
Show file tree
Hide file tree
Showing 13 changed files with 302 additions and 274 deletions.
4 changes: 0 additions & 4 deletions dpnp/backend/include/dpnp_iface_fptr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,7 @@ enum class DPNPFuncName : size_t
DPNP_FN_ARCTAN2, /**< Used in numpy.arctan2() impl */
DPNP_FN_ARCTANH, /**< Used in numpy.arctanh() impl */
DPNP_FN_ARGMAX, /**< Used in numpy.argmax() impl */
DPNP_FN_ARGMAX_EXT, /**< Used in numpy.argmax() impl, requires extra
parameters */
DPNP_FN_ARGMIN, /**< Used in numpy.argmin() impl */
DPNP_FN_ARGMIN_EXT, /**< Used in numpy.argmin() impl, requires extra
parameters */
DPNP_FN_ARGSORT, /**< Used in numpy.argsort() impl */
DPNP_FN_ARGSORT_EXT, /**< Used in numpy.argsort() impl, requires extra
parameters */
Expand Down
50 changes: 0 additions & 50 deletions dpnp/backend/kernels/dpnp_krnl_searching.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,6 @@ void (*dpnp_argmax_default_c)(void *,
void *,
size_t) = dpnp_argmax_c<_DataType, _idx_DataType>;

template <typename _DataType, typename _idx_DataType>
DPCTLSyclEventRef (*dpnp_argmax_ext_c)(DPCTLSyclQueueRef,
void *,
void *,
size_t,
const DPCTLEventVectorRef) =
dpnp_argmax_c<_DataType, _idx_DataType>;

template <typename _DataType, typename _idx_DataType>
class dpnp_argmin_c_kernel;

Expand Down Expand Up @@ -133,14 +125,6 @@ void (*dpnp_argmin_default_c)(void *,
void *,
size_t) = dpnp_argmin_c<_DataType, _idx_DataType>;

template <typename _DataType, typename _idx_DataType>
DPCTLSyclEventRef (*dpnp_argmin_ext_c)(DPCTLSyclQueueRef,
void *,
void *,
size_t,
const DPCTLEventVectorRef) =
dpnp_argmin_c<_DataType, _idx_DataType>;

void func_map_init_searching(func_map_t &fmap)
{
fmap[DPNPFuncName::DPNP_FN_ARGMAX][eft_INT][eft_INT] = {
Expand All @@ -160,23 +144,6 @@ void func_map_init_searching(func_map_t &fmap)
fmap[DPNPFuncName::DPNP_FN_ARGMAX][eft_DBL][eft_LNG] = {
eft_LNG, (void *)dpnp_argmax_default_c<double, int64_t>};

fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_INT][eft_INT] = {
eft_INT, (void *)dpnp_argmax_ext_c<int32_t, int32_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_INT][eft_LNG] = {
eft_LNG, (void *)dpnp_argmax_ext_c<int32_t, int64_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_LNG][eft_INT] = {
eft_INT, (void *)dpnp_argmax_ext_c<int64_t, int32_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_LNG][eft_LNG] = {
eft_LNG, (void *)dpnp_argmax_ext_c<int64_t, int64_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_FLT][eft_INT] = {
eft_INT, (void *)dpnp_argmax_ext_c<float, int32_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_FLT][eft_LNG] = {
eft_LNG, (void *)dpnp_argmax_ext_c<float, int64_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_DBL][eft_INT] = {
eft_INT, (void *)dpnp_argmax_ext_c<double, int32_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_DBL][eft_LNG] = {
eft_LNG, (void *)dpnp_argmax_ext_c<double, int64_t>};

fmap[DPNPFuncName::DPNP_FN_ARGMIN][eft_INT][eft_INT] = {
eft_INT, (void *)dpnp_argmin_default_c<int32_t, int32_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMIN][eft_INT][eft_LNG] = {
Expand All @@ -194,22 +161,5 @@ void func_map_init_searching(func_map_t &fmap)
fmap[DPNPFuncName::DPNP_FN_ARGMIN][eft_DBL][eft_LNG] = {
eft_LNG, (void *)dpnp_argmin_default_c<double, int64_t>};

fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_INT][eft_INT] = {
eft_INT, (void *)dpnp_argmin_ext_c<int32_t, int32_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_INT][eft_LNG] = {
eft_LNG, (void *)dpnp_argmin_ext_c<int32_t, int64_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_LNG][eft_INT] = {
eft_INT, (void *)dpnp_argmin_ext_c<int64_t, int32_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_LNG][eft_LNG] = {
eft_LNG, (void *)dpnp_argmin_ext_c<int64_t, int64_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_FLT][eft_INT] = {
eft_INT, (void *)dpnp_argmin_ext_c<float, int32_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_FLT][eft_LNG] = {
eft_LNG, (void *)dpnp_argmin_ext_c<float, int64_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_DBL][eft_INT] = {
eft_INT, (void *)dpnp_argmin_ext_c<double, int32_t>};
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_DBL][eft_LNG] = {
eft_LNG, (void *)dpnp_argmin_ext_c<double, int64_t>};

return;
}
10 changes: 0 additions & 10 deletions dpnp/dpnp_algo/dpnp_algo.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
DPNP_FN_ALLCLOSE
DPNP_FN_ALLCLOSE_EXT
DPNP_FN_ARANGE
DPNP_FN_ARGMAX
DPNP_FN_ARGMAX_EXT
DPNP_FN_ARGMIN
DPNP_FN_ARGMIN_EXT
DPNP_FN_ARGSORT
DPNP_FN_ARGSORT_EXT
DPNP_FN_CBRT
Expand Down Expand Up @@ -375,12 +371,6 @@ Sorting functions
cpdef dpnp_descriptor dpnp_argsort(dpnp_descriptor array1)
cpdef dpnp_descriptor dpnp_sort(dpnp_descriptor array1)

"""
Searching functions
"""
cpdef dpnp_descriptor dpnp_argmax(dpnp_descriptor array1)
cpdef dpnp_descriptor dpnp_argmin(dpnp_descriptor array1)

"""
Trigonometric functions
"""
Expand Down
119 changes: 0 additions & 119 deletions dpnp/dpnp_algo/dpnp_algo_searching.pxi

This file was deleted.

32 changes: 4 additions & 28 deletions dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,24 +490,14 @@ def argmax(self, axis=None, out=None):
"""
Returns array of indices of the maximum values along the given axis.
Parameters
----------
axis : {None, integer}
If None, the index is into the flattened array, otherwise along
the specified axis
out : {None, array}, optional
Array into which the result can be placed. Its type is preserved
and it must be of the right shape to hold the output.
Returns
-------
index_array : {integer_array}
Refer to :obj:`dpnp.argmax` for full documentation.
Examples
--------
>>> import dpnp as np
>>> a = np.arange(6).reshape(2,3)
>>> a.argmax()
5
array(5)
>>> a.argmax(0)
array([1, 1, 1])
>>> a.argmax(1)
Expand All @@ -520,21 +510,7 @@ def argmin(self, axis=None, out=None):
"""
Return array of indices to the minimum values along the given axis.
Parameters
----------
axis : {None, integer}
If None, the index is into the flattened array, otherwise along
the specified axis
out : {None, array}, optional
Array into which the result can be placed. Its type is preserved
and it must be of the right shape to hold the output.
Returns
-------
ndarray or scalar
If multi-dimension input, returns a new ndarray of indices to the
minimum values along the given axis. Otherwise, returns a scalar
of index to the minimum values along the given axis.
Refer to :obj:`dpnp.argmin` for full documentation.
"""
return dpnp.argmin(self, axis, out)
Expand Down
Loading

0 comments on commit 63f630c

Please sign in to comment.