diff --git a/dpnp/dpnp_algo/dpnp_algo_mathematical.pxi b/dpnp/dpnp_algo/dpnp_algo_mathematical.pxi index f2111e4e671..2b8d63c6d2d 100644 --- a/dpnp/dpnp_algo/dpnp_algo_mathematical.pxi +++ b/dpnp/dpnp_algo/dpnp_algo_mathematical.pxi @@ -39,7 +39,6 @@ __all__ += [ "dpnp_ediff1d", "dpnp_fabs", "dpnp_fmod", - "dpnp_gradient", "dpnp_fmax", "dpnp_fmin", "dpnp_modf", @@ -123,36 +122,6 @@ cpdef utils.dpnp_descriptor dpnp_fmod(utils.dpnp_descriptor x1_obj, return call_fptr_2in_1out_strides(DPNP_FN_FMOD_EXT, x1_obj, x2_obj, dtype, out, where) -cpdef utils.dpnp_descriptor dpnp_gradient(utils.dpnp_descriptor y1, int dx=1): - - cdef size_t size = y1.size - - y1_obj = y1.get_array() - - # create result array with type given by FPTR data - cdef shape_type_c result_shape = utils._object_to_tuple(size) - cdef utils.dpnp_descriptor result = utils_py.create_output_descriptor_py(result_shape, - dpnp.default_float_type(y1_obj.sycl_queue), - None, - device=y1_obj.sycl_device, - usm_type=y1_obj.usm_type, - sycl_queue=y1_obj.sycl_queue) - - cdef double cur = (y1.get_pyobj()[1] - y1.get_pyobj()[0]) / dx - - result.get_pyobj().flat[0] = cur - - cur = (y1.get_pyobj()[-1] - y1.get_pyobj()[-2]) / dx - - result.get_pyobj().flat[size - 1] = cur - - for i in range(1, size - 1): - cur = (y1.get_pyobj()[i + 1] - y1.get_pyobj()[i - 1]) / (2 * dx) - result.get_pyobj().flat[i] = cur - - return result - - cpdef utils.dpnp_descriptor dpnp_fmax(utils.dpnp_descriptor x1_obj, utils.dpnp_descriptor x2_obj, object dtype=None, diff --git a/dpnp/dpnp_iface_mathematical.py b/dpnp/dpnp_iface_mathematical.py index 716696cacbd..b0d0c7b6123 100644 --- a/dpnp/dpnp_iface_mathematical.py +++ b/dpnp/dpnp_iface_mathematical.py @@ -46,6 +46,7 @@ import dpctl.tensor as dpt import dpctl.tensor._tensor_elementwise_impl as ti import dpctl.tensor._type_utils as dtu +import dpctl.utils as dpu import numpy from dpctl.tensor._type_utils import _acceptance_fn_divide from numpy.core.numeric import ( @@ -63,7 +64,6 @@ dpnp_fmax, dpnp_fmin, dpnp_fmod, - dpnp_gradient, dpnp_modf, dpnp_trapz, ) @@ -168,6 +168,169 @@ def _get_reduction_res_dt(a, dtype, _out): return dtu._to_device_supported_dtype(dtype, a.sycl_device) +def _gradient_build_dx(f, axes, *varargs): + """Build an array with distance per each dimension.""" + + len_axes = len(axes) + n = len(varargs) + if n == 0: + # no spacing argument - use 1 in all axes + dx = [1.0] * len_axes + elif n == 1 and numpy.ndim(varargs[0]) == 0: + dpnp.check_supported_arrays_type( + varargs[0], scalar_type=True, all_scalars=True + ) + + # single scalar for all axes + dx = varargs * len_axes + elif n == len_axes: + # scalar or 1d array for each axis + dx = list(varargs) + for i, distances in enumerate(dx): + dpnp.check_supported_arrays_type( + distances, scalar_type=True, all_scalars=True + ) + + if numpy.ndim(distances) == 0: + continue + if distances.ndim != 1: + raise ValueError("distances must be either scalars or 1d") + + if len(distances) != f.shape[axes[i]]: + raise ValueError( + "when 1d, distances must match " + "the length of the corresponding dimension" + ) + + if dpnp.issubdtype(distances.dtype, dpnp.integer): + # Convert integer types to default float type to avoid modular + # arithmetic in dpnp.diff(distances). 
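+                # (for an unsigned dtype a decreasing spacing would wrap
+                # around, e.g. uint8: 1 - 2 == 255 instead of -1)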
+                distances = distances.astype(dpnp.default_float_type())
+            diffx = dpnp.diff(distances)
+
+            # if distances are constant reduce to the scalar case
+            # since it brings a consistent speedup
+            if (diffx == diffx[0]).all():
+                diffx = diffx[0]
+            dx[i] = diffx
+    else:
+        raise TypeError("invalid number of arguments")
+    return dx
+
+
+def _gradient_num_diff_2nd_order_interior(
+    f, ax_dx, out, slices, axis, uniform_spacing
+):
+    """Numerical differentiation: 2nd order interior."""
+
+    slice1, slice2, slice3, slice4 = slices
+    ndim = f.ndim
+
+    slice1[axis] = slice(1, -1)
+    slice2[axis] = slice(None, -2)
+    slice3[axis] = slice(1, -1)
+    slice4[axis] = slice(2, None)
+
+    if uniform_spacing:
+        out[tuple(slice1)] = (f[tuple(slice4)] - f[tuple(slice2)]) / (
+            2.0 * ax_dx
+        )
+    else:
+        dx1 = ax_dx[0:-1]
+        dx2 = ax_dx[1:]
+        a = -(dx2) / (dx1 * (dx1 + dx2))
+        b = (dx2 - dx1) / (dx1 * dx2)
+        c = dx1 / (dx2 * (dx1 + dx2))
+
+        # fix the shape for broadcasting
+        shape = [1] * ndim
+        shape[axis] = -1
+        # TODO: use shape.setter once dpctl#1699 is resolved
+        # a.shape = b.shape = c.shape = shape
+        a = a.reshape(shape)
+        b = b.reshape(shape)
+        c = c.reshape(shape)
+
+        # 1D equivalent -- out[1:-1] = a * f[:-2] + b * f[1:-1] + c * f[2:]
+        out[tuple(slice1)] = (
+            a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)]
+        )
+
+
+def _gradient_num_diff_edges(
+    f, ax_dx, out, slices, axis, uniform_spacing, edge_order
+):
+    """Numerical differentiation: 1st and 2nd order edges."""
+
+    slice1, slice2, slice3, slice4 = slices
+
+    # Numerical differentiation: 1st order edges
+    if edge_order == 1:
+        slice1[axis] = 0
+        slice2[axis] = 1
+        slice3[axis] = 0
+        dx_0 = ax_dx if uniform_spacing else ax_dx[0]
+
+        # 1D equivalent -- out[0] = (f[1] - f[0]) / (x[1] - x[0])
+        out[tuple(slice1)] = (f[tuple(slice2)] - f[tuple(slice3)]) / dx_0
+
+        slice1[axis] = -1
+        slice2[axis] = -1
+        slice3[axis] = -2
+        dx_n = ax_dx if uniform_spacing else ax_dx[-1]
+
+        # 1D equivalent -- out[-1] = (f[-1] - f[-2]) / (x[-1] - x[-2])
+        out[tuple(slice1)] = (f[tuple(slice2)] - f[tuple(slice3)]) / dx_n
+
+    # Numerical differentiation: 2nd order edges
+    else:
+        slice1[axis] = 0
+        slice2[axis] = 0
+        slice3[axis] = 1
+        slice4[axis] = 2
+        if uniform_spacing:
+            a = -1.5 / ax_dx
+            b = 2.0 / ax_dx
+            c = -0.5 / ax_dx
+        else:
+            dx1 = ax_dx[0]
+            dx2 = ax_dx[1]
+            a = -(2.0 * dx1 + dx2) / (dx1 * (dx1 + dx2))
+            b = (dx1 + dx2) / (dx1 * dx2)
+            c = -dx1 / (dx2 * (dx1 + dx2))
+
+        # 1D equivalent -- out[0] = a * f[0] + b * f[1] + c * f[2]
+        out[tuple(slice1)] = (
+            a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)]
+        )
+
+        slice1[axis] = -1
+        slice2[axis] = -3
+        slice3[axis] = -2
+        slice4[axis] = -1
+        if uniform_spacing:
+            a = 0.5 / ax_dx
+            b = -2.0 / ax_dx
+            c = 1.5 / ax_dx
+        else:
+            dx1 = ax_dx[-2]
+            dx2 = ax_dx[-1]
+            a = (dx2) / (dx1 * (dx1 + dx2))
+            b = -(dx2 + dx1) / (dx1 * dx2)
+            c = (2.0 * dx2 + dx1) / (dx2 * (dx1 + dx2))
+
+        # 1D equivalent -- out[-1] = a * f[-3] + b * f[-2] + c * f[-1]
+        out[tuple(slice1)] = (
+            a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)]
+        )
+
+
 _ABS_DOCSTRING = """
 Calculates the absolute value for each element `x_i` of input array `x`.
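Note: the coefficients `a`, `b`, `c` in `_gradient_num_diff_2nd_order_interior` are the standard
three-point finite-difference stencil for a non-uniform grid with steps `dx1 = x[i] - x[i-1]` and
`dx2 = x[i+1] - x[i]`. A minimal NumPy-only sketch of the idea (the helper `interior_coeffs` is
illustrative only and not part of this patch), checked against a quadratic, for which the stencil
is exact:

    import numpy as np

    def interior_coeffs(dx1, dx2):
        # same a, b, c as in _gradient_num_diff_2nd_order_interior above
        a = -dx2 / (dx1 * (dx1 + dx2))
        b = (dx2 - dx1) / (dx1 * dx2)
        c = dx1 / (dx2 * (dx1 + dx2))
        return a, b, c

    x = np.array([0.0, 0.5, 2.0])    # non-uniform grid
    f = 3.0 * x**2 + 2.0 * x + 1.0   # f'(x) = 6*x + 2
    a, b, c = interior_coeffs(x[1] - x[0], x[2] - x[1])
    assert np.isclose(a * f[0] + b * f[1] + c * f[2], 6.0 * x[1] + 2.0)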
@@ -1682,51 +1845,206 @@ def fmod(x1, x2, /, out=None, *, where=True, dtype=None, subok=True, **kwargs):
     )


-def gradient(x1, *varargs, **kwargs):
+def gradient(f, *varargs, axis=None, edge_order=1):
     """
-    Return the gradient of an array.
+    Return the gradient of an N-dimensional array.
+
+    The gradient is computed using second order accurate central differences
+    in the interior points and either first or second order accurate
+    one-sided (forward or backward) differences at the boundaries.
+    The returned gradient hence has the same shape as the input array.

     For full documentation refer to :obj:`numpy.gradient`.

-    Limitations
-    -----------
-    Parameter `y1` is supported as :class:`dpnp.ndarray`.
-    Argument `varargs[0]` is supported as `int`.
-    Keyword argument `kwargs` is currently unsupported.
-    Otherwise the function will be executed sequentially on CPU.
-    Input array data types are limited by supported DPNP :ref:`Data types`.
+    Parameters
+    ----------
+    f : {dpnp.ndarray, usm_ndarray}
+        An N-dimensional array containing samples of a scalar function.
+    varargs : {scalar, list of scalars, list of arrays}, optional
+        Spacing between `f` values. Default unitary spacing for all dimensions.
+        Spacing can be specified using:
+
+        1. Single scalar to specify a sample distance for all dimensions.
+        2. N scalars to specify a constant sample distance for each dimension,
+           i.e. `dx`, `dy`, `dz`, ...
+        3. N arrays to specify the coordinates of the values along each
+           dimension of `f`. The length of the array must match the size of
+           the corresponding dimension.
+        4. Any combination of N scalars/arrays with the meaning of 2. and 3.
+
+        If `axis` is given, the number of `varargs` must equal the number of
+        axes.
+        Default: ``1``.
+    axis : {None, int, tuple of ints}, optional
+        Gradient is calculated only along the given axis or axes.
+        The default is to calculate the gradient for all the axes of the input
+        array. `axis` may be negative, in which case it counts from the last to
+        the first axis.
+        Default: ``None``.
+    edge_order : {1, 2}, optional
+        Gradient is calculated using N-th order accurate differences
+        at the boundaries.
+        Default: ``1``.
+
+    Returns
+    -------
+    gradient : {dpnp.ndarray, tuple of ndarray}
+        A tuple of :class:`dpnp.ndarray` (or a single :class:`dpnp.ndarray` if
+        there is only one dimension) corresponding to the derivatives of `f`
+        with respect to each dimension.
+        Each derivative has the same shape as `f`.

     See Also
     --------
     :obj:`dpnp.diff` : Calculate the n-th discrete difference along the given
                        axis.
+    :obj:`dpnp.ediff1d` : Calculate the differences between consecutive
+                          elements of an array.

     Examples
     --------
     >>> import dpnp as np
-    >>> y = np.array([1, 2, 4, 7, 11, 16], dtype=float)
-    >>> result = np.gradient(y)
-    >>> [x for x in result]
-    [1.0, 1.5, 2.5, 3.5, 4.5, 5.0]
-    >>> result = np.gradient(y, 2)
-    >>> [x for x in result]
-    [0.5, 0.75, 1.25, 1.75, 2.25, 2.5]
+    >>> f = np.array([1, 2, 4, 7, 11, 16], dtype=float)
+    >>> np.gradient(f)
+    array([1. , 1.5, 2.5, 3.5, 4.5, 5. ])
+    >>> np.gradient(f, 2)
+    array([0.5 , 0.75, 1.25, 1.75, 2.25, 2.5 ])
+
+    Spacing can also be specified with an array that represents the
+    coordinates of the values `f` along the dimensions.
+    For instance a uniform spacing:
+
+    >>> x = np.arange(f.size)
+    >>> np.gradient(f, x)
+    array([1. , 1.5, 2.5, 3.5, 4.5, 5. ])
+
+    Or a non-uniform one:
+
+    >>> x = np.array([0., 1., 1.5, 3.5, 4., 6.], dtype=float)
+    >>> np.gradient(f, x)
+    array([1. , 3. , 3.5, 6.7, 6.9, 2.5])
+
+    For two-dimensional arrays, the return will be two arrays ordered by
+    axis. In this example the first array stands for the gradient in
+    rows and the second one in columns direction:
+
+    >>> np.gradient(np.array([[1, 2, 6], [3, 4, 5]], dtype=float))
+    (array([[ 2.,  2., -1.],
+            [ 2.,  2., -1.]]),
+     array([[1. , 2.5, 4. ],
+            [1. , 1. , 1. ]]))
+
+    In this example the spacing is also specified:
+    uniform for axis=0 and non-uniform for axis=1
+
+    >>> dx = 2.
+    >>> y = np.array([1., 1.5, 3.5])
+    >>> np.gradient(np.array([[1, 2, 6], [3, 4, 5]], dtype=float), dx, y)
+    (array([[ 1. ,  1. , -0.5],
+            [ 1. ,  1. , -0.5]]),
+     array([[2. , 2. , 2. ],
+            [2. , 1.7, 0.5]]))
+
+    It is possible to specify how boundaries are treated using `edge_order`:
+
+    >>> x = np.array([0, 1, 2, 3, 4])
+    >>> f = x**2
+    >>> np.gradient(f, edge_order=1)
+    array([1., 2., 4., 6., 7.])
+    >>> np.gradient(f, edge_order=2)
+    array([0., 2., 4., 6., 8.])
+
+    The `axis` keyword can be used to specify a subset of axes along which
+    the gradient is calculated:
+
+    >>> np.gradient(np.array([[1, 2, 6], [3, 4, 5]], dtype=float), axis=0)
+    array([[ 2.,  2., -1.],
+           [ 2.,  2., -1.]])

     """

-    x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
-    if x1_desc and not kwargs:
-        if len(varargs) > 1:
-            pass
-        elif len(varargs) == 1 and not isinstance(varargs[0], int):
-            pass
+    dpnp.check_supported_arrays_type(f)
+    ndim = f.ndim  # number of dimensions
+
+    if axis is None:
+        axes = tuple(range(ndim))
+    else:
+        axes = normalize_axis_tuple(axis, ndim)
+
+    dx = _gradient_build_dx(f, axes, *varargs)
+    if edge_order > 2:
+        raise ValueError("'edge_order' greater than 2 not supported")
+
+    # Use central differences on interior and one-sided differences on the
+    # endpoints. This preserves second-order accuracy over the full domain.
+    outvals = []
+
+    # create slice objects --- initially all are [:, :, ..., :]
+    slice1 = [slice(None)] * ndim
+    slice2 = [slice(None)] * ndim
+    slice3 = [slice(None)] * ndim
+    slice4 = [slice(None)] * ndim
+
+    otype = f.dtype
+    if dpnp.issubdtype(otype, dpnp.inexact):
+        pass
+    else:
+        # All other types convert to floating point.
+        # First check if f is a dpnp integer type; if so, convert f to default
+        # float type to avoid modular arithmetic when computing changes in f.
+        if dpnp.issubdtype(otype, dpnp.integer):
+            f = f.astype(dpnp.default_float_type())
+        otype = dpnp.default_float_type()
+
+    for axis_, ax_dx in zip(axes, dx):
+        if f.shape[axis_] < edge_order + 1:
+            raise ValueError(
+                "Shape of array too small to calculate a numerical gradient, "
+                "at least (edge_order + 1) elements are required."
+ ) + + # result allocation + if dpnp.isscalar(ax_dx): + usm_type = f.usm_type else: - if len(varargs) == 0: - return dpnp_gradient(x1_desc).get_pyobj() + usm_type = dpu.get_coerced_usm_type([f.usm_type, ax_dx.usm_type]) + out = dpnp.empty_like(f, dtype=otype, usm_type=usm_type) + + # spacing for the current axis + uniform_spacing = numpy.ndim(ax_dx) == 0 + + # Numerical differentiation: 2nd order interior + _gradient_num_diff_2nd_order_interior( + f, + ax_dx, + out, + (slice1, slice2, slice3, slice4), + axis_, + uniform_spacing, + ) + + # Numerical differentiation: 1st and 2nd order edges + _gradient_num_diff_edges( + f, + ax_dx, + out, + (slice1, slice2, slice3, slice4), + axis_, + uniform_spacing, + edge_order, + ) + + outvals.append(out) - return dpnp_gradient(x1_desc, varargs[0]).get_pyobj() + # reset the slice object in this dimension to ":" + slice1[axis_] = slice(None) + slice2[axis_] = slice(None) + slice3[axis_] = slice(None) + slice4[axis_] = slice(None) - return call_origin(numpy.gradient, x1, *varargs, **kwargs) + if len(axes) == 1: + return outvals[0] + return tuple(outvals) _IMAG_DOCSTRING = """ diff --git a/tests/skipped_tests_gpu_no_fp64.tbl b/tests/skipped_tests_gpu_no_fp64.tbl index 7a999c99617..c209c876df6 100644 --- a/tests/skipped_tests_gpu_no_fp64.tbl +++ b/tests/skipped_tests_gpu_no_fp64.tbl @@ -1,7 +1,3 @@ -tests/test_mathematical.py::TestGradient::test_gradient_y1_dx[3.5-array0] -tests/test_mathematical.py::TestGradient::test_gradient_y1_dx[3.5-array1] -tests/test_mathematical.py::TestGradient::test_gradient_y1_dx[3.5-array2] - tests/test_strides.py::test_strides_1arg[(10,)-int32-fabs] tests/test_strides.py::test_strides_1arg[(10,)-int64-fabs] tests/test_strides.py::test_strides_1arg[(10,)-None-fabs] diff --git a/tests/test_mathematical.py b/tests/test_mathematical.py index 69b590b386c..4a86cdc081e 100644 --- a/tests/test_mathematical.py +++ b/tests/test_mathematical.py @@ -9,6 +9,7 @@ assert_array_equal, assert_equal, assert_raises, + assert_raises_regex, ) import dpnp @@ -23,7 +24,6 @@ get_float_dtypes, get_integer_dtypes, has_support_aspect64, - is_cpu_device, ) from .test_umath import ( _get_numpy_arrays_1in_1out, @@ -73,6 +73,35 @@ def test_angle_complex(self, dtype, deg): assert_dtype_allclose(result, expected) +@pytest.mark.usefixtures("allow_fall_back_on_numpy") +class TestConvolve: + def test_object(self): + d = [1.0] * 100 + k = [1.0] * 3 + assert_array_almost_equal(dpnp.convolve(d, k)[2:-2], dpnp.full(98, 3)) + + def test_no_overwrite(self): + d = dpnp.ones(100) + k = dpnp.ones(3) + dpnp.convolve(d, k) + assert_array_equal(d, dpnp.ones(100)) + assert_array_equal(k, dpnp.ones(3)) + + def test_mode(self): + d = dpnp.ones(100) + k = dpnp.ones(3) + default_mode = dpnp.convolve(d, k, mode="full") + full_mode = dpnp.convolve(d, k, mode="f") + assert_array_equal(full_mode, default_mode) + # integer mode + with assert_raises(ValueError): + dpnp.convolve(d, k, mode=-1) + assert_array_equal(dpnp.convolve(d, k, mode=2), full_mode) + # illegal arguments + with assert_raises(TypeError): + dpnp.convolve(d, k, mode=None) + + class TestClip: @pytest.mark.parametrize( "dtype", get_all_dtypes(no_bool=True, no_none=True, no_complex=True) @@ -582,33 +611,347 @@ def test_prepend_append_axis_error(self, xp): assert_raises(numpy.AxisError, xp.diff, a, axis=3, append=0) -@pytest.mark.usefixtures("allow_fall_back_on_numpy") -class TestConvolve: - def test_object(self): - d = [1.0] * 100 - k = [1.0] * 3 - assert_array_almost_equal(dpnp.convolve(d, k)[2:-2], dpnp.full(98, 3)) 
+class TestGradient: + @pytest.mark.parametrize("dt", get_all_dtypes(no_none=True, no_bool=True)) + def test_basic(self, dt): + x = numpy.array([[1, 1], [3, 4]], dtype=dt) + ix = dpnp.array(x) - def test_no_overwrite(self): - d = dpnp.ones(100) - k = dpnp.ones(3) - dpnp.convolve(d, k) - assert_array_equal(d, dpnp.ones(100)) - assert_array_equal(k, dpnp.ones(3)) + expected = numpy.gradient(x) + result = dpnp.gradient(ix) + assert_array_equal(result, expected) - def test_mode(self): - d = dpnp.ones(100) - k = dpnp.ones(3) - default_mode = dpnp.convolve(d, k, mode="full") - full_mode = dpnp.convolve(d, k, mode="f") - assert_array_equal(full_mode, default_mode) - # integer mode - with assert_raises(ValueError): - dpnp.convolve(d, k, mode=-1) - assert_array_equal(dpnp.convolve(d, k, mode=2), full_mode) - # illegal arguments - with assert_raises(TypeError): - dpnp.convolve(d, k, mode=None) + @pytest.mark.parametrize( + "args", + [3.0, numpy.array(3.0), numpy.cumsum(numpy.ones(5))], + ids=["scalar", "array", "cumsum"], + ) + @pytest.mark.parametrize("dt", get_all_dtypes(no_none=True, no_bool=True)) + def test_args_1d(self, args, dt): + x = numpy.arange(5, dtype=dt) + ix = dpnp.array(x) + + if numpy.isscalar(args): + iargs = args + else: + iargs = dpnp.array(args) + + expected = numpy.gradient(x, args) + result = dpnp.gradient(ix, iargs) + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize( + "args", [1.5, numpy.array(1.5)], ids=["scalar", "array"] + ) + @pytest.mark.parametrize("dt", get_all_dtypes(no_none=True, no_bool=True)) + def test_args_2d(self, args, dt): + x = numpy.arange(25, dtype=dt).reshape(5, 5) + ix = dpnp.array(x) + + if numpy.isscalar(args): + iargs = args + else: + iargs = dpnp.array(args) + + expected = numpy.gradient(x, args) + result = dpnp.gradient(ix, iargs) + for gr, igr in zip(expected, result): + assert_dtype_allclose(igr, gr) + + @pytest.mark.parametrize("dt", get_all_dtypes(no_none=True, no_bool=True)) + def test_args_2d_uneven(self, dt): + x = numpy.arange(25, dtype=dt).reshape(5, 5) + ix = dpnp.array(x) + + dx = numpy.array([1.0, 2.0, 5.0, 9.0, 11.0]) + idx = dpnp.array(dx) + + expected = numpy.gradient(x, dx, dx) + result = dpnp.gradient(ix, idx, idx) + for gr, igr in zip(expected, result): + assert_dtype_allclose(igr, gr) + + @pytest.mark.parametrize("dt", get_all_dtypes(no_none=True, no_bool=True)) + def test_args_2d_mix_with_scalar(self, dt): + x = numpy.arange(25, dtype=dt).reshape(5, 5) + ix = dpnp.array(x) + + dx = numpy.cumsum(numpy.ones(5)) + idx = dpnp.array(dx) + + expected = numpy.gradient(x, dx, 2) + result = dpnp.gradient(ix, idx, 2) + for gr, igr in zip(expected, result): + assert_dtype_allclose(igr, gr) + + @pytest.mark.parametrize("dt", get_all_dtypes(no_none=True, no_bool=True)) + def test_axis_args_2d(self, dt): + x = numpy.arange(25, dtype=dt).reshape(5, 5) + ix = dpnp.array(x) + + dx = numpy.cumsum(numpy.ones(5)) + idx = dpnp.array(dx) + + expected = numpy.gradient(x, dx, axis=1) + result = dpnp.gradient(ix, idx, axis=1) + for gr, igr in zip(expected, result): + assert_dtype_allclose(igr, gr) + + @pytest.mark.parametrize("xp", [numpy, dpnp]) + def test_args_2d_error(self, xp): + x = xp.arange(25).reshape(5, 5) + dx = xp.cumsum(xp.ones(5)) + assert_raises_regex( + ValueError, + ".*scalars or 1d", + xp.gradient, + x, + xp.stack([dx] * 2, axis=-1), + 1, + ) + + @pytest.mark.parametrize("xp", [numpy, dpnp]) + def test_badargs(self, xp): + x = xp.arange(25).reshape(5, 5) + dx = xp.cumsum(xp.ones(5)) + + # wrong sizes + 
assert_raises(ValueError, xp.gradient, x, x, xp.ones(2)) + assert_raises(ValueError, xp.gradient, x, 1, xp.ones(2)) + assert_raises(ValueError, xp.gradient, x, xp.ones(2), xp.ones(2)) + # wrong number of arguments + assert_raises(TypeError, xp.gradient, x, x) + assert_raises(TypeError, xp.gradient, x, dx, axis=(0, 1)) + assert_raises(TypeError, xp.gradient, x, dx, dx, dx) + assert_raises(TypeError, xp.gradient, x, 1, 1, 1) + assert_raises(TypeError, xp.gradient, x, dx, dx, axis=1) + assert_raises(TypeError, xp.gradient, x, 1, 1, axis=1) + + @pytest.mark.parametrize( + "x", + [ + numpy.linspace(0, 1, 10), + numpy.sort(numpy.random.RandomState(0).random(10)), + ], + ids=["linspace", "random_sorted"], + ) + @pytest.mark.parametrize("dt", get_float_dtypes()) + # testing that the relative numerical error is close to numpy + def test_second_order_accurate(self, x, dt): + x = x.astype(dt) + dx = x[1] - x[0] + y = 2 * x**3 + 4 * x**2 + 2 * x + + iy = dpnp.array(y) + idx = dpnp.array(dx) + + expected = numpy.gradient(y, dx, edge_order=2) + result = dpnp.gradient(iy, idx, edge_order=2) + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize("edge_order", [1, 2]) + @pytest.mark.parametrize("axis", [0, 1, (0, 1)]) + @pytest.mark.parametrize("dt", get_float_dtypes()) + def test_spacing_axis_scalar(self, edge_order, axis, dt): + x = numpy.array([0, 2.0, 3.0, 4.0, 5.0, 5.0], dtype=dt) + x = numpy.tile(x, (6, 1)) + x.reshape(-1, 1) + ix = dpnp.array(x) + + expected = numpy.gradient(x, 1.0, axis=axis, edge_order=edge_order) + result = dpnp.gradient(ix, 1.0, axis=axis, edge_order=edge_order) + for gr, igr in zip(expected, result): + assert_dtype_allclose(igr, gr) + + @pytest.mark.parametrize("edge_order", [1, 2]) + @pytest.mark.parametrize("axis", [(0, 1), None]) + @pytest.mark.parametrize("dt", get_float_dtypes()) + @pytest.mark.parametrize( + "dx", + [numpy.arange(6.0), numpy.array([0.0, 0.5, 1.0, 3.0, 5.0, 7.0])], + ids=["even", "uneven"], + ) + def test_spacing_axis_two_args(self, edge_order, axis, dt, dx): + x = numpy.array([0, 2.0, 3.0, 4.0, 5.0, 5.0], dtype=dt) + x = numpy.tile(x, (6, 1)) + x.reshape(-1, 1) + + ix = dpnp.array(x) + idx = dpnp.array(dx) + + expected = numpy.gradient(x, dx, dx, axis=axis, edge_order=edge_order) + result = dpnp.gradient(ix, idx, idx, axis=axis, edge_order=edge_order) + for gr, igr in zip(expected, result): + assert_dtype_allclose(igr, gr) + + @pytest.mark.parametrize("edge_order", [1, 2]) + @pytest.mark.parametrize("axis", [0, 1]) + @pytest.mark.parametrize("dt", get_float_dtypes()) + @pytest.mark.parametrize( + "dx", + [numpy.arange(6.0), numpy.array([0.0, 0.5, 1.0, 3.0, 5.0, 7.0])], + ids=["even", "uneven"], + ) + def test_spacing_axis_args(self, edge_order, axis, dt, dx): + x = numpy.array([0, 2.0, 3.0, 4.0, 5.0, 5.0], dtype=dt) + x = numpy.tile(x, (6, 1)) + x.reshape(-1, 1) + + ix = dpnp.array(x) + idx = dpnp.array(dx) + + expected = numpy.gradient(x, dx, axis=axis, edge_order=edge_order) + result = dpnp.gradient(ix, idx, axis=axis, edge_order=edge_order) + for gr, igr in zip(expected, result): + assert_dtype_allclose(igr, gr) + + @pytest.mark.parametrize("edge_order", [1, 2]) + @pytest.mark.parametrize("dt", get_float_dtypes()) + def test_spacing_mix_args(self, edge_order, dt): + x = numpy.array([0, 2.0, 3.0, 4.0, 5.0, 5.0], dtype=dt) + x = numpy.tile(x, (6, 1)) + x.reshape(-1, 1) + x_uneven = numpy.array([0.0, 0.5, 1.0, 3.0, 5.0, 7.0]) + x_even = numpy.arange(6.0) + + ix = dpnp.array(x) + ix_uneven = dpnp.array(x_uneven) + ix_even = 
dpnp.array(x_even) + + expected = numpy.gradient( + x, x_even, x_uneven, axis=(0, 1), edge_order=edge_order + ) + result = dpnp.gradient( + ix, ix_even, ix_uneven, axis=(0, 1), edge_order=edge_order + ) + for gr, igr in zip(expected, result): + assert_dtype_allclose(igr, gr) + + expected = numpy.gradient( + x, x_uneven, x_even, axis=(1, 0), edge_order=edge_order + ) + result = dpnp.gradient( + ix, ix_uneven, ix_even, axis=(1, 0), edge_order=edge_order + ) + for gr, igr in zip(expected, result): + assert_dtype_allclose(igr, gr) + + @pytest.mark.parametrize("axis", [0, 1, -1, (1, 0), None]) + def test_specific_axes(self, axis): + x = numpy.array([[1, 1], [3, 4]]) + ix = dpnp.array(x) + + expected = numpy.gradient(x, axis=axis) + result = dpnp.gradient(ix, axis=axis) + for gr, igr in zip(expected, result): + assert_dtype_allclose(igr, gr) + + def test_axis_scalar_args(self): + x = numpy.array([[1, 1], [3, 4]]) + ix = dpnp.array(x) + + expected = numpy.gradient(x, 2, 3, axis=(1, 0)) + result = dpnp.gradient(ix, 2, 3, axis=(1, 0)) + for gr, igr in zip(expected, result): + assert_dtype_allclose(igr, gr) + + @pytest.mark.parametrize("xp", [numpy, dpnp]) + def test_wrong_number_of_args(self, xp): + x = xp.array([[1, 1], [3, 4]]) + assert_raises(TypeError, xp.gradient, x, 1, 2, axis=1) + + @pytest.mark.parametrize("xp", [numpy, dpnp]) + def test_wrong_axis(self, xp): + x = xp.array([[1, 1], [3, 4]]) + assert_raises(numpy.AxisError, xp.gradient, x, axis=3) + + @pytest.mark.parametrize( + "size, edge_order", + [ + pytest.param(2, 1), + pytest.param(3, 2), + ], + ) + def test_min_size_with_edge_order(self, size, edge_order): + x = numpy.arange(size) + ix = dpnp.array(x) + + expected = numpy.gradient(x, edge_order=edge_order) + result = dpnp.gradient(ix, edge_order=edge_order) + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize( + "size, edge_order", + [ + pytest.param(0, 1), + pytest.param(0, 2), + pytest.param(1, 1), + pytest.param(1, 2), + pytest.param(2, 2), + ], + ) + @pytest.mark.parametrize("xp", [numpy, dpnp]) + def test_wrong_size_with_edge_order(self, size, edge_order, xp): + assert_raises( + ValueError, xp.gradient, xp.arange(size), edge_order=edge_order + ) + + @pytest.mark.parametrize( + "dt", [numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64] + ) + def test_f_decreasing_unsigned_int(self, dt): + x = numpy.array([5, 4, 3, 2, 1], dtype=dt) + ix = dpnp.array(x) + + expected = numpy.gradient(x) + result = dpnp.gradient(ix) + assert_array_equal(result, expected) + + @pytest.mark.parametrize( + "dt", [numpy.int8, numpy.int16, numpy.int32, numpy.int64] + ) + def test_f_signed_int_big_jump(self, dt): + maxint = numpy.iinfo(dt).max + x = numpy.array([-1, maxint], dtype=dt) + dx = numpy.array([1, 3]) + + ix = dpnp.array(x) + idx = dpnp.array(dx) + + expected = numpy.gradient(x, dx) + result = dpnp.gradient(ix, idx) + assert_array_equal(result, expected) + + @pytest.mark.parametrize( + "dt", [numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64] + ) + def test_x_decreasing_unsigned(self, dt): + x = numpy.array([3, 2, 1], dtype=dt) + f = numpy.array([0, 2, 4]) + + dp_x = dpnp.array(x) + dp_f = dpnp.array(f) + + expected = numpy.gradient(f, x) + result = dpnp.gradient(dp_f, dp_x) + assert_array_equal(result, expected) + + @pytest.mark.parametrize( + "dt", [numpy.int8, numpy.int16, numpy.int32, numpy.int64] + ) + def test_x_signed_int_big_jump(self, dt): + minint = numpy.iinfo(dt).min + maxint = numpy.iinfo(dt).max + x = numpy.array([-1, maxint], dtype=dt) + f = 
numpy.array([minint // 2, 0]) + + dp_x = dpnp.array(x) + dp_f = dpnp.array(f) + + expected = numpy.gradient(f, x) + result = dpnp.gradient(dp_f, dp_x) + assert_array_equal(result, expected) + + def test_return_type(self): + x = dpnp.array([[1, 2], [2, 3]]) + res = dpnp.gradient(x) + assert type(res) is tuple @pytest.mark.parametrize("dtype1", get_all_dtypes()) @@ -1384,32 +1727,6 @@ def test_trapz_with_dx_params(self, y_array, dx): assert_array_equal(expected, result) -class TestGradient: - @pytest.mark.parametrize( - "array", [[2, 3, 6, 8, 4, 9], [3.0, 4.0, 7.5, 9.0], [2, 6, 8, 10]] - ) - def test_gradient_y1(self, array): - np_y = numpy.array(array) - dpnp_y = dpnp.array(array) - - result = dpnp.gradient(dpnp_y) - expected = numpy.gradient(np_y) - assert_array_equal(expected, result) - - @pytest.mark.usefixtures("allow_fall_back_on_numpy") - @pytest.mark.parametrize( - "array", [[2, 3, 6, 8, 4, 9], [3.0, 4.0, 7.5, 9.0], [2, 6, 8, 10]] - ) - @pytest.mark.parametrize("dx", [2, 3.5]) - def test_gradient_y1_dx(self, array, dx): - np_y = numpy.array(array) - dpnp_y = dpnp.array(array) - - result = dpnp.gradient(dpnp_y, dx) - expected = numpy.gradient(np_y, dx) - assert_array_equal(expected, result) - - class TestRoundingFuncs: @pytest.fixture( params=[ diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index fae4dd52221..e66c1a55b87 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -625,6 +625,11 @@ def test_reduce_hypot(device): [-3.0, -2.0, -1.0, 1.0, 2.0, 3.0], [2.0, 2.0, 2.0, 2.0, 2.0, 2.0], ), + pytest.param( + "gradient", + [1.0, 2.0, 4.0, 7.0, 11.0, 16.0], + [0.0, 1.0, 1.5, 3.5, 4.0, 6.0], + ), pytest.param( "histogram_bin_edges", [0, 0, 0, 1, 2, 3, 3, 4, 5], @@ -691,7 +696,7 @@ def test_2in_1out(func, data1, data2, device): x2 = dpnp.array(data2, device=device) result = getattr(dpnp, func)(x1, x2) - assert_allclose(result, expected) + assert_dtype_allclose(result, expected) assert_sycl_queue_equal(result.sycl_queue, x1.sycl_queue) assert_sycl_queue_equal(result.sycl_queue, x2.sycl_queue) diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index eab59cf001b..f42b6a769bc 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -539,6 +539,7 @@ def test_norm(usm_type, ord, axis): pytest.param("exp2", [0.0, 1.0, 2.0]), pytest.param("expm1", [1.0e-10, 1.0, 2.0, 4.0, 7.0]), pytest.param("floor", [-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]), + pytest.param("gradient", [1, 2, 4, 7, 11, 16]), pytest.param("histogram_bin_edges", [0, 0, 0, 1, 2, 3, 3, 4, 5]), pytest.param( "imag", [complex(1.0, 2.0), complex(3.0, 4.0), complex(5.0, 6.0)] @@ -622,6 +623,9 @@ def test_1in_1out(func, data, usm_type): pytest.param("dot", [3 + 2j, 4 + 1j, 5], [1, 2 + 3j, 3]), pytest.param("fmax", [[0.0, 1.0, 2.0]], [[3.0, 4.0, 5.0]]), pytest.param("fmin", [[0.0, 1.0, 2.0]], [[3.0, 4.0, 5.0]]), + pytest.param( + "gradient", [1, 2, 4, 7, 11, 16], [0.0, 1.0, 1.5, 3.5, 4.0, 6.0] + ), pytest.param( "hypot", [[1.0, 2.0, 3.0, 4.0]], [[-1.0, -2.0, -4.0, -5.0]] ), diff --git a/tests/third_party/cupy/math_tests/test_sumprod.py b/tests/third_party/cupy/math_tests/test_sumprod.py index 18a74a76330..f36086755e9 100644 --- a/tests/third_party/cupy/math_tests/test_sumprod.py +++ b/tests/third_party/cupy/math_tests/test_sumprod.py @@ -717,9 +717,15 @@ def test_diff_invalid_axis(self): ), ) ) -@pytest.mark.skip("gradient() is not implemented yet") class TestGradient: def _gradient(self, xp, dtype, shape, spacing, axis, edge_order): + if ( + not has_support_aspect64() + and shape == 
(10, 20, 30) + and spacing == "arrays" + ): + pytest.skip("too big values") + x = testing.shaped_random(shape, xp, dtype=dtype) if axis is None: normalized_axes = tuple(range(x.ndim)) @@ -755,7 +761,9 @@ def test_gradient_floating(self, xp, dtype): # https://github.com/numpy/numpy/issues/15207 @testing.with_requires("numpy>=1.18.1") @testing.for_int_dtypes(no_bool=True) - @testing.numpy_cupy_allclose(atol=1e-6, rtol=1e-5) + @testing.numpy_cupy_allclose( + atol=1e-6, rtol=1e-5, type_check=has_support_aspect64() + ) def test_gradient_int(self, xp, dtype): return self._gradient( xp, dtype, self.shape, self.spacing, self.axis, self.edge_order @@ -773,7 +781,6 @@ def test_gradient_float16(self, xp): ) -@pytest.mark.skip("gradient() is not implemented yet") class TestGradientErrors: def test_gradient_invalid_spacings1(self): # more spacings than axes
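A quick end-to-end usage sketch of the new `dpnp.gradient`, mirroring the docstring examples
(assumes a SYCL device is available; the commented output additionally assumes float64 support):

    import dpnp as np

    f = np.array([1.0, 2.0, 4.0, 7.0, 11.0, 16.0])
    x = np.array([0.0, 1.0, 1.5, 3.5, 4.0, 6.0])  # non-uniform coordinates

    np.gradient(f)       # unit spacing -> array([1. , 1.5, 2.5, 3.5, 4.5, 5. ])
    np.gradient(f, 2.0)  # constant sample distance of 2 for the single axis
    np.gradient(f, x)    # coordinate array, one entry per sample of f
    np.gradient(f, x, edge_order=2)  # second order accurate boundaries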