From 79cb518a7fc8e4cd04ec6a84b67c51fa4a77e24b Mon Sep 17 00:00:00 2001 From: vtavana <120411540+vtavana@users.noreply.github.com> Date: Tue, 7 Nov 2023 22:38:38 -0600 Subject: [PATCH] implement dpnp.max and dpnp.min using dpctl.tensor functions (#1602) * implement dpnp.max and dpnp.min using dpctl.tensor functions * address comments * fix a few issues * fix doc-string * add axis==None condition for zero-size array * add new tests to improve coverage * update tests to reduce duplication --- .github/workflows/conda-package.yml | 1 + dpnp/backend/include/dpnp_iface_fptr.hpp | 26 +- dpnp/backend/kernels/dpnp_krnl_statistics.cpp | 42 ---- dpnp/dpnp_algo/dpnp_algo.pxd | 10 - dpnp/dpnp_algo/dpnp_algo_statistics.pxi | 171 ------------- dpnp/dpnp_algo/dpnp_elementwise_common.py | 2 - dpnp/dpnp_array.py | 25 +- dpnp/dpnp_iface_mathematical.py | 12 +- dpnp/dpnp_iface_statistics.py | 233 ++++++++++++------ tests/skipped_tests.tbl | 5 - tests/test_amin_amax.py | 10 +- tests/test_statistics.py | 65 ++++- tests/test_sycl_queue.py | 2 + tests/test_usm_type.py | 2 + .../cupy/core_tests/test_ndarray_reduction.py | 59 +++++ 15 files changed, 318 insertions(+), 347 deletions(-) diff --git a/.github/workflows/conda-package.yml b/.github/workflows/conda-package.yml index 76b24b44287..472c0d6be9a 100644 --- a/.github/workflows/conda-package.yml +++ b/.github/workflows/conda-package.yml @@ -13,6 +13,7 @@ env: # TODO: to add test_arraymanipulation.py back to the scope once crash on Windows is gone TEST_SCOPE: >- test_arraycreation.py + test_amin_amax.py test_dot.py test_dparray.py test_copy.py diff --git a/dpnp/backend/include/dpnp_iface_fptr.hpp b/dpnp/backend/include/dpnp_iface_fptr.hpp index be64c6727f9..7d8b195935c 100644 --- a/dpnp/backend/include/dpnp_iface_fptr.hpp +++ b/dpnp/backend/include/dpnp_iface_fptr.hpp @@ -212,20 +212,18 @@ enum class DPNPFuncName : size_t DPNP_FN_MATRIX_RANK_EXT, /**< Used in numpy.linalg.matrix_rank() impl, requires extra parameters */ DPNP_FN_MAX, /**< Used in numpy.max() impl */ - DPNP_FN_MAX_EXT, /**< Used in numpy.max() impl, requires extra parameters */ - DPNP_FN_MAXIMUM, /**< Used in numpy.fmax() impl */ - DPNP_FN_MAXIMUM_EXT, /**< Used in numpy.fmax() impl , requires extra - parameters */ - DPNP_FN_MEAN, /**< Used in numpy.mean() impl */ - DPNP_FN_MEDIAN, /**< Used in numpy.median() impl */ - DPNP_FN_MEDIAN_EXT, /**< Used in numpy.median() impl, requires extra - parameters */ - DPNP_FN_MIN, /**< Used in numpy.min() impl */ - DPNP_FN_MIN_EXT, /**< Used in numpy.min() impl, requires extra parameters */ - DPNP_FN_MINIMUM, /**< Used in numpy.fmin() impl */ - DPNP_FN_MINIMUM_EXT, /**< Used in numpy.fmax() impl, requires extra - parameters */ - DPNP_FN_MODF, /**< Used in numpy.modf() impl */ + DPNP_FN_MAXIMUM, /**< Used in numpy.fmax() impl */ + DPNP_FN_MAXIMUM_EXT, /**< Used in numpy.fmax() impl , requires extra + parameters */ + DPNP_FN_MEAN, /**< Used in numpy.mean() impl */ + DPNP_FN_MEDIAN, /**< Used in numpy.median() impl */ + DPNP_FN_MEDIAN_EXT, /**< Used in numpy.median() impl, requires extra + parameters */ + DPNP_FN_MIN, /**< Used in numpy.min() impl */ + DPNP_FN_MINIMUM, /**< Used in numpy.fmin() impl */ + DPNP_FN_MINIMUM_EXT, /**< Used in numpy.fmax() impl, requires extra + parameters */ + DPNP_FN_MODF, /**< Used in numpy.modf() impl */ DPNP_FN_MODF_EXT, /**< Used in numpy.modf() impl, requires extra parameters */ DPNP_FN_MULTIPLY, /**< Used in numpy.multiply() impl */ diff --git a/dpnp/backend/kernels/dpnp_krnl_statistics.cpp b/dpnp/backend/kernels/dpnp_krnl_statistics.cpp index 3acf53f0de4..5c0ca1f6591 100644 --- a/dpnp/backend/kernels/dpnp_krnl_statistics.cpp +++ b/dpnp/backend/kernels/dpnp_krnl_statistics.cpp @@ -503,18 +503,6 @@ void (*dpnp_max_default_c)(void *, const shape_elem_type *, size_t) = dpnp_max_c<_DataType>; -template -DPCTLSyclEventRef (*dpnp_max_ext_c)(DPCTLSyclQueueRef, - void *, - void *, - const size_t, - const shape_elem_type *, - size_t, - const shape_elem_type *, - size_t, - const DPCTLEventVectorRef) = - dpnp_max_c<_DataType>; - template DPCTLSyclEventRef dpnp_mean_c(DPCTLSyclQueueRef q_ref, void *array1_in, @@ -887,18 +875,6 @@ void (*dpnp_min_default_c)(void *, const shape_elem_type *, size_t) = dpnp_min_c<_DataType>; -template -DPCTLSyclEventRef (*dpnp_min_ext_c)(DPCTLSyclQueueRef, - void *, - void *, - const size_t, - const shape_elem_type *, - size_t, - const shape_elem_type *, - size_t, - const DPCTLEventVectorRef) = - dpnp_min_c<_DataType>; - template DPCTLSyclEventRef dpnp_nanvar_c(DPCTLSyclQueueRef q_ref, void *array1_in, @@ -1283,15 +1259,6 @@ void func_map_init_statistics(func_map_t &fmap) fmap[DPNPFuncName::DPNP_FN_MAX][eft_DBL][eft_DBL] = { eft_DBL, (void *)dpnp_max_default_c}; - fmap[DPNPFuncName::DPNP_FN_MAX_EXT][eft_INT][eft_INT] = { - eft_INT, (void *)dpnp_max_ext_c}; - fmap[DPNPFuncName::DPNP_FN_MAX_EXT][eft_LNG][eft_LNG] = { - eft_LNG, (void *)dpnp_max_ext_c}; - fmap[DPNPFuncName::DPNP_FN_MAX_EXT][eft_FLT][eft_FLT] = { - eft_FLT, (void *)dpnp_max_ext_c}; - fmap[DPNPFuncName::DPNP_FN_MAX_EXT][eft_DBL][eft_DBL] = { - eft_DBL, (void *)dpnp_max_ext_c}; - fmap[DPNPFuncName::DPNP_FN_MEAN][eft_INT][eft_INT] = { eft_DBL, (void *)dpnp_mean_default_c}; fmap[DPNPFuncName::DPNP_FN_MEAN][eft_LNG][eft_LNG] = { @@ -1340,15 +1307,6 @@ void func_map_init_statistics(func_map_t &fmap) fmap[DPNPFuncName::DPNP_FN_MIN][eft_DBL][eft_DBL] = { eft_DBL, (void *)dpnp_min_default_c}; - fmap[DPNPFuncName::DPNP_FN_MIN_EXT][eft_INT][eft_INT] = { - eft_INT, (void *)dpnp_min_ext_c}; - fmap[DPNPFuncName::DPNP_FN_MIN_EXT][eft_LNG][eft_LNG] = { - eft_LNG, (void *)dpnp_min_ext_c}; - fmap[DPNPFuncName::DPNP_FN_MIN_EXT][eft_FLT][eft_FLT] = { - eft_FLT, (void *)dpnp_min_ext_c}; - fmap[DPNPFuncName::DPNP_FN_MIN_EXT][eft_DBL][eft_DBL] = { - eft_DBL, (void *)dpnp_min_ext_c}; - fmap[DPNPFuncName::DPNP_FN_NANVAR][eft_INT][eft_INT] = { eft_INT, (void *)dpnp_nanvar_default_c}; fmap[DPNPFuncName::DPNP_FN_NANVAR][eft_LNG][eft_LNG] = { diff --git a/dpnp/dpnp_algo/dpnp_algo.pxd b/dpnp/dpnp_algo/dpnp_algo.pxd index 7a71531c72a..c2e4747fbdb 100644 --- a/dpnp/dpnp_algo/dpnp_algo.pxd +++ b/dpnp/dpnp_algo/dpnp_algo.pxd @@ -106,14 +106,10 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na DPNP_FN_MATMUL_EXT DPNP_FN_MATRIX_RANK DPNP_FN_MATRIX_RANK_EXT - DPNP_FN_MAX - DPNP_FN_MAX_EXT DPNP_FN_MAXIMUM DPNP_FN_MAXIMUM_EXT DPNP_FN_MEDIAN DPNP_FN_MEDIAN_EXT - DPNP_FN_MIN - DPNP_FN_MIN_EXT DPNP_FN_MINIMUM DPNP_FN_MINIMUM_EXT DPNP_FN_MODF @@ -369,12 +365,6 @@ Array manipulation routines cpdef dpnp_descriptor dpnp_repeat(dpnp_descriptor array1, repeats, axes=*) -""" -Statistics functions -""" -cpdef dpnp_descriptor dpnp_min(dpnp_descriptor a, axis) - - """ Sorting functions """ diff --git a/dpnp/dpnp_algo/dpnp_algo_statistics.pxi b/dpnp/dpnp_algo/dpnp_algo_statistics.pxi index 43463c7791d..34e0684fcbf 100644 --- a/dpnp/dpnp_algo/dpnp_algo_statistics.pxi +++ b/dpnp/dpnp_algo/dpnp_algo_statistics.pxi @@ -38,9 +38,7 @@ and the rest of the library __all__ += [ "dpnp_average", "dpnp_correlate", - "dpnp_max", "dpnp_median", - "dpnp_min", "dpnp_nanvar", "dpnp_std", "dpnp_var", @@ -64,16 +62,6 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*custom_statistic_1in_1out_func_ptr_t)(c_dpct void *, void * , shape_elem_type * , size_t, shape_elem_type * , size_t, const c_dpctl.DPCTLEventVectorRef) -ctypedef c_dpctl.DPCTLSyclEventRef(*custom_statistic_1in_1out_func_ptr_t_max)(c_dpctl.DPCTLSyclQueueRef, - void *, - void * , - const size_t, - shape_elem_type * , - size_t, - shape_elem_type * , - size_t, - const c_dpctl.DPCTLEventVectorRef) - cdef utils.dpnp_descriptor call_fptr_custom_std_var_1in_1out(DPNPFuncName fptr_name, utils.dpnp_descriptor x1, ddof): cdef shape_type_c x1_shape = x1.shape @@ -177,86 +165,6 @@ cpdef utils.dpnp_descriptor dpnp_correlate(utils.dpnp_descriptor x1, utils.dpnp_ return result -cdef utils.dpnp_descriptor _dpnp_max(utils.dpnp_descriptor x1, _axis_, shape_type_c result_shape): - cdef shape_type_c x1_shape = x1.shape - cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(x1.dtype) - - cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_MAX_EXT, param1_type, param1_type) - - x1_obj = x1.get_array() - - # create result array with type given by FPTR data - cdef utils.dpnp_descriptor result = utils.create_output_descriptor(result_shape, - kernel_data.return_type, - None, - device=x1_obj.sycl_device, - usm_type=x1_obj.usm_type, - sycl_queue=x1_obj.sycl_queue) - - result_sycl_queue = result.get_array().sycl_queue - - cdef c_dpctl.SyclQueue q = result_sycl_queue - cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref() - - cdef custom_statistic_1in_1out_func_ptr_t_max func = kernel_data.ptr - cdef shape_type_c axis - cdef Py_ssize_t axis_size = 0 - cdef shape_type_c axis_ = axis - - if _axis_ is not None: - axis = _axis_ - axis_.reserve(len(axis)) - for shape_it in axis: - axis_.push_back(shape_it) - axis_size = len(axis) - - cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref, - x1.get_data(), - result.get_data(), - result.size, - x1_shape.data(), - x1.ndim, - axis_.data(), - axis_size, - NULL) # dep_events_ref - - with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref) - c_dpctl.DPCTLEvent_Delete(event_ref) - - return result - - -cpdef utils.dpnp_descriptor dpnp_max(utils.dpnp_descriptor x1, axis): - cdef shape_type_c x1_shape = x1.shape - cdef shape_type_c output_shape - - if axis is None: - axis_ = axis - output_shape.push_back(1) - else: - if isinstance(axis, int): - if axis < 0: - axis_ = tuple([x1.ndim - axis]) - else: - axis_ = tuple([axis]) - else: - _axis_ = [] - for i in range(len(axis)): - if axis[i] < 0: - _axis_.append(x1.ndim - axis[i]) - else: - _axis_.append(axis[i]) - axis_ = tuple(_axis_) - - output_shape.resize(len(x1_shape) - len(axis_), 0) - ind = 0 - for id, shape_axis in enumerate(x1_shape): - if id not in axis_: - output_shape[ind] = shape_axis - ind += 1 - - return _dpnp_max(x1, axis_, output_shape) - cpdef utils.dpnp_descriptor dpnp_median(utils.dpnp_descriptor array1): cdef shape_type_c x1_shape = array1.shape cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(array1.dtype) @@ -301,85 +209,6 @@ cpdef utils.dpnp_descriptor dpnp_median(utils.dpnp_descriptor array1): return result -cpdef utils.dpnp_descriptor _dpnp_min(utils.dpnp_descriptor x1, _axis_, shape_type_c shape_output): - cdef shape_type_c x1_shape = x1.shape - cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(x1.dtype) - - cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_MIN_EXT, param1_type, param1_type) - - x1_obj = x1.get_array() - - cdef utils.dpnp_descriptor result = utils.create_output_descriptor(shape_output, - kernel_data.return_type, - None, - device=x1_obj.sycl_device, - usm_type=x1_obj.usm_type, - sycl_queue=x1_obj.sycl_queue) - - result_sycl_queue = result.get_array().sycl_queue - - cdef c_dpctl.SyclQueue q = result_sycl_queue - cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref() - - cdef custom_statistic_1in_1out_func_ptr_t_max func = kernel_data.ptr - cdef shape_type_c axis - cdef Py_ssize_t axis_size = 0 - cdef shape_type_c axis_ = axis - - if _axis_ is not None: - axis = _axis_ - axis_.reserve(len(axis)) - for shape_it in axis: - if shape_it < 0: - raise ValueError("DPNP algo::_dpnp_min(): Negative values in 'shape' are not allowed") - axis_.push_back(shape_it) - axis_size = len(axis) - - cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref, - x1.get_data(), - result.get_data(), - result.size, - x1_shape.data(), - x1.ndim, - axis_.data(), - axis_size, - NULL) # dep_events_ref - - with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref) - c_dpctl.DPCTLEvent_Delete(event_ref) - - return result - - -cpdef utils.dpnp_descriptor dpnp_min(utils.dpnp_descriptor x1, axis): - cdef shape_type_c x1_shape = x1.shape - cdef shape_type_c shape_output - - if axis is None: - axis_ = axis - shape_output = (1,) - else: - if isinstance(axis, int): - if axis < 0: - axis_ = tuple([x1.ndim - axis]) - else: - axis_ = tuple([axis]) - else: - _axis_ = [] - for i in range(len(axis)): - if axis[i] < 0: - _axis_.append(x1.ndim - axis[i]) - else: - _axis_.append(axis[i]) - axis_ = tuple(_axis_) - - for id, shape_axis in enumerate(x1_shape): - if id not in axis_: - shape_output.push_back(shape_axis) - - return _dpnp_min(x1, axis_, shape_output) - - cpdef utils.dpnp_descriptor dpnp_nanvar(utils.dpnp_descriptor arr, ddof): # dpnp_isnan does not support USM array as input in comparison to dpnp.isnan cdef utils.dpnp_descriptor mask_arr = dpnp.get_dpnp_descriptor(dpnp.isnan(arr.get_pyobj()), diff --git a/dpnp/dpnp_algo/dpnp_elementwise_common.py b/dpnp/dpnp_algo/dpnp_elementwise_common.py index 39b2199e914..315b266c803 100644 --- a/dpnp/dpnp_algo/dpnp_elementwise_common.py +++ b/dpnp/dpnp_algo/dpnp_elementwise_common.py @@ -1,5 +1,3 @@ -# cython: language_level=3 -# distutils: language = c++ # -*- coding: utf-8 -*- # ***************************************************************************** # Copyright (c) 2023, Intel Corporation diff --git a/dpnp/dpnp_array.py b/dpnp/dpnp_array.py index 9e8a8096a0f..fb454a78642 100644 --- a/dpnp/dpnp_array.py +++ b/dpnp/dpnp_array.py @@ -25,7 +25,6 @@ # ***************************************************************************** import dpctl.tensor as dpt -import numpy import dpnp @@ -939,11 +938,15 @@ def max( self, axis=None, out=None, - keepdims=numpy._NoValue, - initial=numpy._NoValue, - where=numpy._NoValue, + keepdims=False, + initial=None, + where=True, ): - """Return the maximum along an axis.""" + """ + Return the maximum along an axis. + + Refer to :obj:`dpnp.max` for full documentation. + """ return dpnp.max(self, axis, out, keepdims, initial, where) @@ -956,11 +959,15 @@ def min( self, axis=None, out=None, - keepdims=numpy._NoValue, - initial=numpy._NoValue, - where=numpy._NoValue, + keepdims=False, + initial=None, + where=True, ): - """Return the minimum along a given axis.""" + """ + Return the minimum along a given axis. + + Refer to :obj:`dpnp.min` for full documentation. + """ return dpnp.min(self, axis, out, keepdims, initial, where) diff --git a/dpnp/dpnp_iface_mathematical.py b/dpnp/dpnp_iface_mathematical.py index 416c8492fc9..330179c2ca4 100644 --- a/dpnp/dpnp_iface_mathematical.py +++ b/dpnp/dpnp_iface_mathematical.py @@ -1401,9 +1401,9 @@ def maximum( :obj:`dpnp.fmax` : Element-wise maximum of two arrays, ignores NaNs. :obj:`dpnp.amax` : The maximum value of an array along a given axis, propagates NaNs. :obj:`dpnp.nanmax` : The maximum value of an array along a given axis, ignores NaNs. - :obj:`dpnp.fmin` : Element-wise minimum of two arrays, ignores NaNs. - :obj:`dpnp.amix` : The minimum value of an array along a given axis, propagates NaNs. - :obj:`dpnp.nanmix` : The minimum value of an array along a given axis, ignores NaNs. + :obj:`dpnp.fmax` : Element-wise maximum of two arrays, ignores NaNs. + :obj:`dpnp.amax` : The maximum value of an array along a given axis, propagates NaNs. + :obj:`dpnp.nanmax` : The maximum value of an array along a given axis, ignores NaNs. Examples -------- @@ -1480,9 +1480,9 @@ def minimum( :obj:`dpnp.fmin` : Element-wise minimum of two arrays, ignores NaNs. :obj:`dpnp.amin` : The minimum value of an array along a given axis, propagates NaNs. :obj:`dpnp.nanmin` : The minimum value of an array along a given axis, ignores NaNs. - :obj:`dpnp.fmax` : Element-wise maximum of two arrays, ignores NaNs. - :obj:`dpnp.amax` : The maximum value of an array along a given axis, propagates NaNs. - :obj:`dpnp.nanmax` : The maximum value of an array along a given axis, ignores NaNs. + :obj:`dpnp.fmin` : Element-wise minimum of two arrays, ignores NaNs. + :obj:`dpnp.amin` : The minimum value of an array along a given axis, propagates NaNs. + :obj:`dpnp.nanmin` : The minimum value of an array along a given axis, ignores NaNs. Examples -------- diff --git a/dpnp/dpnp_iface_statistics.py b/dpnp/dpnp_iface_statistics.py index c7254ad6d01..653b323c9e1 100644 --- a/dpnp/dpnp_iface_statistics.py +++ b/dpnp/dpnp_iface_statistics.py @@ -1,5 +1,3 @@ -# cython: language_level=3 -# distutils: language = c++ # -*- coding: utf-8 -*- # ***************************************************************************** # Copyright (c) 2016-2023, Intel Corporation @@ -352,69 +350,102 @@ def histogram(a, bins=10, range=None, density=None, weights=None): ) -def max(x1, axis=None, out=None, keepdims=False, initial=None, where=True): +def max(a, axis=None, out=None, keepdims=False, initial=None, where=True): """ Return the maximum of an array or maximum along an axis. + For full documentation refer to :obj:`numpy.max`. + + Returns + ------- + out : dpnp.ndarray + Maximum of `a`. + Limitations ----------- - Input array is supported as :obj:`dpnp.ndarray`. - Otherwise the function will be executed sequentially on CPU. - Parameter `out` is supported only with default value ``None``. + Input and output arrays are only supported as either :class:`dpnp.ndarray` + or :class:`dpctl.tensor.usm_ndarray`. + Parameters `where`, and `initial` are supported only with their default values. + Otherwise ``NotImplementedError`` exception will be raised. Input array data types are limited by supported DPNP :ref:`Data types`. + See Also + -------- + :obj:`dpnp.min` : Return the minimum of an array. + :obj:`dpnp.maximum` : Element-wise maximum of two arrays, propagates NaNs. + :obj:`dpnp.fmax` : Element-wise maximum of two arrays, ignores NaNs. + :obj:`dpnp.amax` : The maximum value of an array along a given axis, propagates NaNs. + :obj:`dpnp.nanmax` : The maximum value of an array along a given axis, ignores NaNs. + Examples -------- >>> import dpnp as np >>> a = np.arange(4).reshape((2,2)) - >>> a.shape - (2, 2) - >>> [i for i in a] - [0, 1, 2, 3] + >>> a + array([[0, 1], + [2, 3]]) >>> np.max(a) - 3 + array(3) + + >>> np.max(a, axis=0) # Maxima along the first axis + array([2, 3]) + >>> np.max(a, axis=1) # Maxima along the second axis + array([1, 3]) + + >>> b = np.arange(5, dtype=float) + >>> b[2] = np.NaN + >>> np.max(b) + array(nan) """ - x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False) - if x1_desc: - # Negative values in 'shape' are not allowed in input array - # 306-322 check on negative and duplicate axis - isaxis = True - if axis is not None: - if dpnp.isscalar(axis): - if axis < 0: - isaxis = False - else: - for val in axis: - if val < 0: - isaxis = False - break - if isaxis: - for i in range(len(axis)): - for j in range(len(axis)): - if i != j: - if axis[i] == axis[j]: - isaxis = False - break - - if not isaxis: - pass - elif out is not None: - pass - elif keepdims: - pass - elif initial is not None: - pass - elif where is not True: - pass + if initial is not None: + raise NotImplementedError( + "initial keyword argument is only supported by its default value." + ) + elif where is not True: + raise NotImplementedError( + "where keyword argument is only supported by its default value." + ) + else: + dpt_array = dpnp.get_usm_ndarray(a) + if dpt_array.size == 0: + # TODO: get rid of this if condition when dpctl supports it + axis = (axis,) if isinstance(axis, int) else axis + for i in range(a.ndim): + if a.shape[i] == 0: + if axis is None or i in axis: + raise ValueError( + "reduction does not support zero-size arrays" + ) + else: + indices = [i for i in range(a.ndim) if i not in axis] + res_shape = tuple([a.shape[i] for i in indices]) + result = dpnp.empty(res_shape, dtype=a.dtype) else: - result_obj = dpnp_max(x1_desc, axis).get_pyobj() - result = dpnp.convert_single_elem_array_to_scalar(result_obj) - + result = dpnp_array._create_from_usm_ndarray( + dpt.max(dpt_array, axis=axis, keepdims=keepdims) + ) + if out is None: return result + else: + if out.shape != result.shape: + raise ValueError( + f"Output array of shape {result.shape} is needed, got {out.shape}." + ) + elif not isinstance(out, dpnp_array): + if isinstance(out, dpt.usm_ndarray): + out = dpnp_array._create_from_usm_ndarray(out) + else: + raise TypeError( + "Output array must be any of supported type, but got {}".format( + type(out) + ) + ) + + dpnp.copyto(out, result, casting="safe") - return call_origin(numpy.max, x1, axis, out, keepdims, initial, where) + return out def mean(x, /, *, axis=None, dtype=None, keepdims=False, out=None, where=True): @@ -564,47 +595,101 @@ def median(x1, axis=None, out=None, overwrite_input=False, keepdims=False): return call_origin(numpy.median, x1, axis, out, overwrite_input, keepdims) -def min(x1, axis=None, out=None, keepdims=False, initial=None, where=True): +def min(a, axis=None, out=None, keepdims=False, initial=None, where=True): """ - Return the minimum along a given axis. + Return the minimum of an array or maximum along an axis. + + For full documentation refer to :obj:`numpy.min`. + + Returns + ------- + out : dpnp.ndarray + Minimum of `a`. Limitations ----------- - Input array is supported as :obj:`dpnp.ndarray`. - Otherwise the function will be executed sequentially on CPU. - Parameter `out` is supported only with default value ``None``. + Input and output arrays are only supported as either :class:`dpnp.ndarray` + or :class:`dpctl.tensor.usm_ndarray`. + Parameters `where`, and `initial` are supported only with their default values. + Otherwise ``NotImplementedError`` exception will be raised. Input array data types are limited by supported DPNP :ref:`Data types`. + See Also + -------- + :obj:`dpnp.max` : Return the maximum of an array. + :obj:`dpnp.minimum` : Element-wise minimum of two arrays, propagates NaNs. + :obj:`dpnp.fmin` : Element-wise minimum of two arrays, ignores NaNs. + :obj:`dpnp.amin` : The minimum value of an array along a given axis, propagates NaNs. + :obj:`dpnp.nanmin` : The minimum value of an array along a given axis, ignores NaNs. + Examples -------- >>> import dpnp as np >>> a = np.arange(4).reshape((2,2)) - >>> a.shape - (2, 2) - >>> [i for i in a] - [0, 1, 2, 3] + >>> a + array([[0, 1], + [2, 3]]) >>> np.min(a) - 0 + array(0) + + >>> np.min(a, axis=0) # Minima along the first axis + array([0, 1]) + >>> np.min(a, axis=1) # Minima along the second axis + array([0, 2]) + + >>> b = np.arange(5, dtype=float) + >>> b[2] = np.NaN + >>> np.min(b) + array(nan) """ - x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False) - if x1_desc: - if out is not None: - pass - elif keepdims: - pass - elif initial is not None: - pass - elif where is not True: - pass + if initial is not None: + raise NotImplementedError( + "initial keyword argument is only supported by its default value." + ) + elif where is not True: + raise NotImplementedError( + "where keyword argument is only supported by its default values." + ) + else: + dpt_array = dpnp.get_usm_ndarray(a) + if dpt_array.size == 0: + # TODO: get rid of this if condition when dpctl supports it + for i in range(a.ndim): + if a.shape[i] == 0: + if axis is None or i in axis: + raise ValueError( + "reduction does not support zero-size arrays" + ) + else: + indices = [i for i in range(a.ndim) if i not in axis] + res_shape = tuple([a.shape[i] for i in indices]) + result = dpnp.empty(res_shape, dtype=a.dtype) else: - result_obj = dpnp_min(x1_desc, axis).get_pyobj() - result = dpnp.convert_single_elem_array_to_scalar(result_obj) - + result = dpnp_array._create_from_usm_ndarray( + dpt.min(dpt_array, axis=axis, keepdims=keepdims) + ) + if out is None: return result + else: + if out.shape != result.shape: + raise ValueError( + f"Output array of shape {result.shape} is needed, got {out.shape}." + ) + elif not isinstance(out, dpnp_array): + if isinstance(out, dpt.usm_ndarray): + out = dpnp_array._create_from_usm_ndarray(out) + else: + raise TypeError( + "Output array must be any of supported type, but got {}".format( + type(out) + ) + ) - return call_origin(numpy.min, x1, axis, out, keepdims, initial, where) + dpnp.copyto(out, result, casting="safe") + + return out def nanvar(x1, axis=None, dtype=None, out=None, ddof=0, keepdims=False): @@ -619,7 +704,7 @@ def nanvar(x1, axis=None, dtype=None, out=None, ddof=0, keepdims=False): Parameter `axis` is supported only with default value ``None``. Parameter `dtype` is supported only with default value ``None``. Parameter `out` is supported only with default value ``None``. - Parameter `keepdims` is supported only with default value ``numpy._NoValue``. + Parameter `keepdims` is supported only with default value ``False``. Otherwise the function will be executed sequentially on CPU. """ @@ -665,7 +750,7 @@ def std(x1, axis=None, dtype=None, out=None, ddof=0, keepdims=False): Parameter `axis` is supported only with default value ``None``. Parameter `dtype` is supported only with default value ``None``. Parameter `out` is supported only with default value ``None``. - Parameter `keepdims` is supported only with default value ``numpy._NoValue``. + Parameter `keepdims` is supported only with default value ``False``. Otherwise the function will be executed sequentially on CPU. Input array data types are limited by supported DPNP :ref:`Data types`. @@ -723,7 +808,7 @@ def var(x1, axis=None, dtype=None, out=None, ddof=0, keepdims=False): Parameter `axis` is supported only with default value ``None``. Parameter `dtype` is supported only with default value ``None``. Parameter `out` is supported only with default value ``None``. - Parameter `keepdims` is supported only with default value ``numpy._NoValue``. + Parameter `keepdims` is supported only with default value ``False``. Otherwise the function will be executed sequentially on CPU. Input array data types are limited by supported DPNP :ref:`Data types`. diff --git a/tests/skipped_tests.tbl b/tests/skipped_tests.tbl index 801b29fdf3c..48490d92f38 100644 --- a/tests/skipped_tests.tbl +++ b/tests/skipped_tests.tbl @@ -118,7 +118,6 @@ tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayFlatte tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayFlatten::test_flatten_order_copied tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayFlatten::test_flatten_order_transposed -tests/third_party/cupy/core_tests/test_ndarray_reduction.py::TestArrayReduction::test_min_nan tests/third_party/cupy/core_tests/test_ndarray_reduction.py::TestArrayReduction::test_ptp_all tests/third_party/cupy/core_tests/test_ndarray_reduction.py::TestArrayReduction::test_ptp_all_keepdims tests/third_party/cupy/core_tests/test_ndarray_reduction.py::TestArrayReduction::test_ptp_axis0 @@ -796,14 +795,10 @@ tests/third_party/cupy/random_tests/test_sample.py::TestMultinomial_param_4_{siz tests/third_party/cupy/random_tests/test_sample.py::TestRandint2::test_bound_float1 tests/third_party/cupy/random_tests/test_sample.py::TestRandint2::test_goodness_of_fit tests/third_party/cupy/random_tests/test_sample.py::TestRandint2::test_goodness_of_fit_2 - tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_bound_1 tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_bound_2 tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_goodness_of_fit tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_goodness_of_fit_2 -tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers::test_high_is_none -tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers::test_normal -tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers::test_size_is_not_none tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype_param_0_{func='argmin', is_module=True, shape=(3, 4)}::test_argminmax_dtype tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype_param_1_{func='argmin', is_module=True, shape=()}::test_argminmax_dtype diff --git a/tests/test_amin_amax.py b/tests/test_amin_amax.py index 7c5bb8b1b50..5e197f5bf13 100644 --- a/tests/test_amin_amax.py +++ b/tests/test_amin_amax.py @@ -7,7 +7,7 @@ from .helper import get_all_dtypes -@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True)) +@pytest.mark.parametrize("dtype", get_all_dtypes()) def test_amax(dtype): a = numpy.array( [ @@ -25,7 +25,7 @@ def test_amax(dtype): assert_allclose(expected, result) -@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True)) +@pytest.mark.parametrize("dtype", get_all_dtypes()) def test_amin(dtype): a = numpy.array( [ @@ -55,8 +55,7 @@ def _get_min_max_input(type, shape): return a.reshape(shape) -@pytest.mark.usefixtures("allow_fall_back_on_numpy") -@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True)) +@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) @pytest.mark.parametrize( "shape", [(4,), (2, 3), (4, 5, 6)], ids=["(4,)", "(2,3)", "(4,5,6)"] ) @@ -74,8 +73,7 @@ def test_amax_diff_shape(dtype, shape): numpy.testing.assert_array_equal(dpnp_res, np_res) -@pytest.mark.usefixtures("allow_fall_back_on_numpy") -@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True)) +@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) @pytest.mark.parametrize( "shape", [(4,), (2, 3), (4, 5, 6)], ids=["(4,)", "(2,3)", "(4,5,6)"] ) diff --git a/tests/test_statistics.py b/tests/test_statistics.py index 4020c6c21d7..2894f24a37b 100644 --- a/tests/test_statistics.py +++ b/tests/test_statistics.py @@ -1,3 +1,4 @@ +import dpctl.tensor as dpt import numpy import pytest from numpy.testing import assert_allclose @@ -21,20 +22,68 @@ def test_median(dtype, size): assert_allclose(dpnp_res, np_res) -@pytest.mark.usefixtures("allow_fall_back_on_numpy") -@pytest.mark.parametrize("axis", [0, 1, -1, 2, -2, (1, 2), (0, -2)]) -@pytest.mark.parametrize( - "dtype", get_all_dtypes(no_none=True, no_bool=True, no_complex=True) -) -def test_max(axis, dtype): +@pytest.mark.parametrize("func", ["max", "min"]) +@pytest.mark.parametrize("axis", [None, 0, 1, -1, 2, -2, (1, 2), (0, -2)]) +@pytest.mark.parametrize("keepdims", [False, True]) +@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) +def test_max_min(func, axis, keepdims, dtype): a = numpy.arange(768, dtype=dtype).reshape((4, 4, 6, 8)) ia = dpnp.array(a) - np_res = numpy.max(a, axis=axis) - dpnp_res = dpnp.max(ia, axis=axis) + np_res = getattr(numpy, func)(a, axis=axis, keepdims=keepdims) + dpnp_res = getattr(dpnp, func)(ia, axis=axis, keepdims=keepdims) + + assert dpnp_res.shape == np_res.shape + assert_allclose(dpnp_res, np_res) + + +@pytest.mark.parametrize("func", ["max", "min"]) +@pytest.mark.parametrize("axis", [None, 0, 1, -1]) +@pytest.mark.parametrize("keepdims", [False, True]) +def test_max_min_bool(func, axis, keepdims): + a = numpy.arange(2, dtype=dpnp.bool) + a = numpy.tile(a, (2, 2)) + ia = dpnp.array(a) + + np_res = getattr(numpy, func)(a, axis=axis, keepdims=keepdims) + dpnp_res = getattr(dpnp, func)(ia, axis=axis, keepdims=keepdims) + + assert dpnp_res.shape == np_res.shape + assert_allclose(dpnp_res, np_res) + + +@pytest.mark.parametrize("func", ["max", "min"]) +def test_max_min_out(func): + a = numpy.arange(6).reshape((2, 3)) + ia = dpnp.array(a) + + np_res = getattr(numpy, func)(a, axis=0) + dpnp_res = dpnp.array(numpy.empty_like(np_res)) + getattr(dpnp, func)(ia, axis=0, out=dpnp_res) + assert_allclose(dpnp_res, np_res) + dpnp_res = dpt.asarray(numpy.empty_like(np_res)) + getattr(dpnp, func)(ia, axis=0, out=dpnp_res) assert_allclose(dpnp_res, np_res) + dpnp_res = numpy.empty_like(np_res) + with pytest.raises(TypeError): + getattr(dpnp, func)(ia, axis=0, out=dpnp_res) + + dpnp_res = dpnp.array(numpy.empty((2, 3))) + with pytest.raises(ValueError): + getattr(dpnp, func)(ia, axis=0, out=dpnp_res) + + +@pytest.mark.parametrize("func", ["max", "min"]) +def test_max_min_NotImplemented(func): + ia = dpnp.arange(5) + + with pytest.raises(NotImplementedError): + getattr(dpnp, func)(ia, where=False) + with pytest.raises(NotImplementedError): + getattr(dpnp, func)(ia, initial=6) + @pytest.mark.usefixtures("allow_fall_back_on_numpy") @pytest.mark.parametrize( diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index 44852b5f513..3c131a6462f 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -260,6 +260,8 @@ def test_meshgrid(device_x, device_y): pytest.param("log10", [1.0, 2.0, 4.0, 7.0]), pytest.param("log1p", [1.0e-10, 1.0, 2.0, 4.0, 7.0]), pytest.param("log2", [1.0, 2.0, 4.0, 7.0]), + pytest.param("max", [1.0, 2.0, 4.0, 7.0]), + pytest.param("min", [1.0, 2.0, 4.0, 7.0]), pytest.param("nancumprod", [1.0, dpnp.nan]), pytest.param("nancumsum", [1.0, dpnp.nan]), pytest.param("nanprod", [1.0, dpnp.nan]), diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index 206cae64326..3060edd4bea 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -350,6 +350,8 @@ def test_meshgrid(usm_type_x, usm_type_y): pytest.param("log1p", [1.0e-10, 1.0, 2.0, 4.0, 7.0]), pytest.param("log2", [1.0, 2.0, 4.0, 7.0]), pytest.param("nanprod", [1.0, 2.0, dp.nan]), + pytest.param("max", [1.0, 2.0, 4.0, 7.0]), + pytest.param("min", [1.0, 2.0, 4.0, 7.0]), pytest.param("negative", [1.0, 0.0, -1.0]), pytest.param("positive", [1.0, 0.0, -1.0]), pytest.param("prod", [1.0, 2.0]), diff --git a/tests/third_party/cupy/core_tests/test_ndarray_reduction.py b/tests/third_party/cupy/core_tests/test_ndarray_reduction.py index ceea3c6259c..952398575f1 100644 --- a/tests/third_party/cupy/core_tests/test_ndarray_reduction.py +++ b/tests/third_party/cupy/core_tests/test_ndarray_reduction.py @@ -215,6 +215,65 @@ def test_ptp_nan_imag(self, xp, dtype): return a.ptp() +@testing.parameterize( + *testing.product( + { + # TODO(leofang): make a @testing.for_all_axes decorator + "shape_and_axis": [ + ((), None), + ((0,), (0,)), + ((0, 2), (0,)), + ((0, 2), (1,)), + ((0, 2), (0, 1)), + ((2, 0), (0,)), + ((2, 0), (1,)), + ((2, 0), (0, 1)), + ((0, 2, 3), (0,)), + ((0, 2, 3), (1,)), + ((0, 2, 3), (2,)), + ((0, 2, 3), (0, 1)), + ((0, 2, 3), (1, 2)), + ((0, 2, 3), (0, 2)), + ((0, 2, 3), (0, 1, 2)), + ((2, 0, 3), (0,)), + ((2, 0, 3), (1,)), + ((2, 0, 3), (2,)), + ((2, 0, 3), (0, 1)), + ((2, 0, 3), (1, 2)), + ((2, 0, 3), (0, 2)), + ((2, 0, 3), (0, 1, 2)), + ((2, 3, 0), (0,)), + ((2, 3, 0), (1,)), + ((2, 3, 0), (2,)), + ((2, 3, 0), (0, 1)), + ((2, 3, 0), (1, 2)), + ((2, 3, 0), (0, 2)), + ((2, 3, 0), (0, 1, 2)), + ], + "order": ("C", "F"), + "func": ("min", "max"), + } + ) +) +class TestArrayReductionZeroSize: + @testing.numpy_cupy_allclose( + contiguous_check=False, accept_error=ValueError + ) + def test_zero_size(self, xp): + shape, axis = self.shape_and_axis + # NumPy only supports axis being an int + if self.func in ("argmax", "argmin"): + if axis is not None and len(axis) == 1: + axis = axis[0] + else: + pytest.skip( + f"NumPy does not support axis={axis} for {self.func}" + ) + # dtype is irrelevant here, just pick one + a = testing.shaped_random(shape, xp, xp.float32, order=self.order) + return getattr(a, self.func)(axis=axis) + + # This class compares CUB results against NumPy's @testing.parameterize( *testing.product(