diff --git a/dpnp/dpnp_algo/dpnp_algo_arraycreation.pxi b/dpnp/dpnp_algo/dpnp_algo_arraycreation.pxi index 5ebb8d157a7..7b90ff1285f 100644 --- a/dpnp/dpnp_algo/dpnp_algo_arraycreation.pxi +++ b/dpnp/dpnp_algo/dpnp_algo_arraycreation.pxi @@ -38,9 +38,6 @@ and the rest of the library __all__ += [ "dpnp_copy", "dpnp_diag", - "dpnp_geomspace", - "dpnp_linspace", - "dpnp_logspace", "dpnp_ptp", "dpnp_trace", "dpnp_vander", @@ -138,116 +135,6 @@ cpdef utils.dpnp_descriptor dpnp_diag(utils.dpnp_descriptor v, int k): return result -cpdef utils.dpnp_descriptor dpnp_geomspace(start, stop, num, endpoint, dtype, axis): - cdef shape_type_c obj_shape = utils._object_to_tuple(num) - cdef utils.dpnp_descriptor result = utils_py.create_output_descriptor_py(obj_shape, dtype, None) - - if endpoint: - steps_count = num - 1 - else: - steps_count = num - - # if there are steps, then fill values - if steps_count > 0: - step = dpnp.power(dpnp.float64(stop) / start, 1.0 / steps_count) - mult = step - for i in range(1, result.size): - result.get_pyobj()[i] = start * mult - mult = mult * step - else: - step = dpnp.nan - - # if result is not empty, then fiil first and last elements - if num > 0: - result.get_pyobj()[0] = start - if endpoint and result.size > 1: - result.get_pyobj()[result.size - 1] = stop - - return result - - -def dpnp_linspace(start, stop, num, dtype=None, device=None, usm_type=None, sycl_queue=None, endpoint=True, retstep=False, axis=0): - usm_type_alloc, sycl_queue_alloc = utils_py.get_usm_allocations([start, stop]) - - # Get sycl_queue. - if sycl_queue is None and device is None: - sycl_queue = sycl_queue_alloc - sycl_queue_normalized = dpnp.get_normalized_queue_device(sycl_queue=sycl_queue, device=device) - - # Get temporary usm_type for getting dtype. - if usm_type is None: - _usm_type = "device" if usm_type_alloc is None else usm_type_alloc - else: - _usm_type = usm_type - - # Get dtype. - if not hasattr(start, "dtype") and not dpnp.isscalar(start): - start = dpnp.asarray(start, usm_type=_usm_type, sycl_queue=sycl_queue_normalized) - if not hasattr(stop, "dtype") and not dpnp.isscalar(stop): - stop = dpnp.asarray(stop, usm_type=_usm_type, sycl_queue=sycl_queue_normalized) - dt = numpy.result_type(start, stop, float(num)) - dt = utils_py.map_dtype_to_device(dt, sycl_queue_normalized.sycl_device) - if dtype is None: - dtype = dt - - if dpnp.isscalar(start) and dpnp.isscalar(stop): - # Call linspace() function for scalars. - res = dpnp_container.linspace(start, - stop, - num, - dtype=dt, - usm_type=_usm_type, - sycl_queue=sycl_queue_normalized, - endpoint=endpoint) - else: - num = operator.index(num) - if num < 0: - raise ValueError("Number of points must be non-negative") - - # Get final usm_type and copy arrays if needed with current dtype, usm_type and sycl_queue. - # Do not need to copy usm_ndarray by usm_type if it is not explicitly stated. - if usm_type is None: - usm_type = _usm_type - if not hasattr(start, "usm_type"): - _start = dpnp.asarray(start, dtype=dt, usm_type=usm_type, sycl_queue=sycl_queue_normalized) - else: - _start = dpnp.asarray(start, dtype=dt, sycl_queue=sycl_queue_normalized) - if not hasattr(stop, "usm_type"): - _stop = dpnp.asarray(stop, dtype=dt, usm_type=usm_type, sycl_queue=sycl_queue_normalized) - else: - _stop = dpnp.asarray(stop, dtype=dt, sycl_queue=sycl_queue_normalized) - else: - _start = dpnp.asarray(start, dtype=dt, usm_type=usm_type, sycl_queue=sycl_queue_normalized) - _stop = dpnp.asarray(stop, dtype=dt, usm_type=usm_type, sycl_queue=sycl_queue_normalized) - - # FIXME: issue #1304. Mathematical operations with scalar don't follow data type. - _num = dpnp.asarray((num - 1) if endpoint else num, dtype=dt, usm_type=usm_type, sycl_queue=sycl_queue_normalized) - - step = (_stop - _start) / _num - - res = dpnp_container.arange(0, - stop=num, - step=1, - dtype=dt, - usm_type=usm_type, - sycl_queue=sycl_queue_normalized) - - res = res.reshape((-1,) + (1,) * step.ndim) - res = res * step + _start - - if endpoint and num > 1: - res[-1] = dpnp_container.full(step.shape, _stop) - - if numpy.issubdtype(dtype, dpnp.integer): - dpnp.floor(res, out=res) - return res.astype(dtype) - - -cpdef utils.dpnp_descriptor dpnp_logspace(start, stop, num, endpoint, base, dtype, axis): - temp = dpnp.linspace(start, stop, num=num, endpoint=endpoint) - return dpnp.get_dpnp_descriptor(dpnp.astype(dpnp.power(base, temp), dtype)) - - cpdef dpnp_ptp(utils.dpnp_descriptor arr, axis=None): cdef shape_type_c shape_arr = arr.shape cdef shape_type_c output_shape diff --git a/dpnp/dpnp_algo/dpnp_arraycreation.py b/dpnp/dpnp_algo/dpnp_arraycreation.py new file mode 100644 index 00000000000..b6d3c612068 --- /dev/null +++ b/dpnp/dpnp_algo/dpnp_arraycreation.py @@ -0,0 +1,258 @@ +import operator + +import numpy + +import dpnp +import dpnp.dpnp_container as dpnp_container +import dpnp.dpnp_utils as utils + +__all__ = [ + "dpnp_geomspace", + "dpnp_linspace", + "dpnp_logspace", +] + + +def dpnp_geomspace( + start, + stop, + num, + dtype=None, + device=None, + usm_type=None, + sycl_queue=None, + endpoint=True, + axis=0, +): + usm_type_alloc, sycl_queue_alloc = utils.get_usm_allocations([start, stop]) + + if sycl_queue is None and device is None: + sycl_queue = sycl_queue_alloc + sycl_queue_normalized = dpnp.get_normalized_queue_device( + sycl_queue=sycl_queue, device=device + ) + + if usm_type is None: + _usm_type = "device" if usm_type_alloc is None else usm_type_alloc + else: + _usm_type = usm_type + + if not dpnp.is_supported_array_type(start): + start = dpnp.asarray( + start, usm_type=_usm_type, sycl_queue=sycl_queue_normalized + ) + if not dpnp.is_supported_array_type(stop): + stop = dpnp.asarray( + stop, usm_type=_usm_type, sycl_queue=sycl_queue_normalized + ) + + dt = numpy.result_type(start, stop, float(num)) + dt = utils.map_dtype_to_device(dt, sycl_queue_normalized.sycl_device) + if dtype is None: + dtype = dt + + if dpnp.any(start == 0) or dpnp.any(stop == 0): + raise ValueError("Geometric sequence cannot include zero") + + out_sign = dpnp.ones( + dpnp.broadcast_arrays(start, stop)[0].shape, + dtype=dt, + usm_type=_usm_type, + sycl_queue=sycl_queue_normalized, + ) + # Avoid negligible real or imaginary parts in output by rotating to + # positive real, calculating, then undoing rotation + if dpnp.issubdtype(dt, dpnp.complexfloating): + all_imag = (start.real == 0.0) & (stop.real == 0.0) + if dpnp.any(all_imag): + start[all_imag] = start[all_imag].imag + stop[all_imag] = stop[all_imag].imag + out_sign[all_imag] = 1j + + both_negative = (dpnp.sign(start) == -1) & (dpnp.sign(stop) == -1) + if dpnp.any(both_negative): + dpnp.negative(start[both_negative], out=start[both_negative]) + dpnp.negative(stop[both_negative], out=stop[both_negative]) + dpnp.negative(out_sign[both_negative], out=out_sign[both_negative]) + + log_start = dpnp.log10(start) + log_stop = dpnp.log10(stop) + result = dpnp_logspace( + log_start, + log_stop, + num=num, + endpoint=endpoint, + base=10.0, + dtype=dtype, + usm_type=_usm_type, + sycl_queue=sycl_queue_normalized, + ) + + if num > 0: + result[0] = start + if num > 1 and endpoint: + result[-1] = stop + + result = out_sign * result + + if axis != 0: + result = dpnp.moveaxis(result, 0, axis) + + return result.astype(dtype, copy=False) + + +def dpnp_linspace( + start, + stop, + num, + dtype=None, + device=None, + usm_type=None, + sycl_queue=None, + endpoint=True, + retstep=False, + axis=0, +): + usm_type_alloc, sycl_queue_alloc = utils.get_usm_allocations([start, stop]) + + if sycl_queue is None and device is None: + sycl_queue = sycl_queue_alloc + sycl_queue_normalized = dpnp.get_normalized_queue_device( + sycl_queue=sycl_queue, device=device + ) + + if usm_type is None: + _usm_type = "device" if usm_type_alloc is None else usm_type_alloc + else: + _usm_type = usm_type + + if not hasattr(start, "dtype") and not dpnp.isscalar(start): + start = dpnp.asarray( + start, usm_type=_usm_type, sycl_queue=sycl_queue_normalized + ) + if not hasattr(stop, "dtype") and not dpnp.isscalar(stop): + stop = dpnp.asarray( + stop, usm_type=_usm_type, sycl_queue=sycl_queue_normalized + ) + + dt = numpy.result_type(start, stop, float(num)) + dt = utils.map_dtype_to_device(dt, sycl_queue_normalized.sycl_device) + if dtype is None: + dtype = dt + + num = operator.index(num) + if num < 0: + raise ValueError("Number of points must be non-negative") + step_num = (num - 1) if endpoint else num + + step_nan = False + if step_num == 0: + step_nan = True + step = dpnp.nan + + if dpnp.isscalar(start) and dpnp.isscalar(stop): + # Call linspace() function for scalars. + res = dpnp_container.linspace( + start, + stop, + num, + dtype=dt, + usm_type=_usm_type, + sycl_queue=sycl_queue_normalized, + endpoint=endpoint, + ) + if retstep is True and step_nan is False: + step = (stop - start) / step_num + else: + _start = dpnp.asarray( + start, + dtype=dt, + usm_type=_usm_type, + sycl_queue=sycl_queue_normalized, + ) + _stop = dpnp.asarray( + stop, dtype=dt, usm_type=_usm_type, sycl_queue=sycl_queue_normalized + ) + + res = dpnp_container.arange( + 0, + stop=num, + step=1, + dtype=dt, + usm_type=_usm_type, + sycl_queue=sycl_queue_normalized, + ) + + if step_nan is False: + step = (_stop - _start) / step_num + res = res.reshape((-1,) + (1,) * step.ndim) + res = res * step + _start + + if endpoint and num > 1: + res[-1] = dpnp_container.full(step.shape, _stop) + + if axis != 0: + res = dpnp.moveaxis(res, 0, axis) + + if numpy.issubdtype(dtype, dpnp.integer): + dpnp.floor(res, out=res) + + res = res.astype(dtype, copy=False) + + if retstep is True: + if dpnp.isscalar(step): + step = dpnp.asarray( + step, usm_type=res.usm_type, sycl_queue=res.sycl_queue + ) + return (res, step) + + return res + + +def dpnp_logspace( + start, + stop, + num=50, + device=None, + usm_type=None, + sycl_queue=None, + endpoint=True, + base=10.0, + dtype=None, + axis=0, +): + if not dpnp.isscalar(base): + usm_type_alloc, sycl_queue_alloc = utils.get_usm_allocations( + [start, stop, base] + ) + + if sycl_queue is None and device is None: + sycl_queue = sycl_queue_alloc + sycl_queue = dpnp.get_normalized_queue_device( + sycl_queue=sycl_queue, device=device + ) + + if usm_type is None: + usm_type = "device" if usm_type_alloc is None else usm_type_alloc + else: + usm_type = usm_type + start = dpnp.asarray(start, usm_type=usm_type, sycl_queue=sycl_queue) + stop = dpnp.asarray(stop, usm_type=usm_type, sycl_queue=sycl_queue) + base = dpnp.asarray(base, usm_type=usm_type, sycl_queue=sycl_queue) + [start, stop, base] = dpnp.broadcast_arrays(start, stop, base) + base = dpnp.expand_dims(base, axis=axis) + + res = dpnp_linspace( + start, + stop, + num=num, + device=device, + usm_type=usm_type, + sycl_queue=sycl_queue, + endpoint=endpoint, + axis=axis, + ) + + if dtype is None: + return dpnp.power(base, res) + return dpnp.power(base, res).astype(dtype, copy=False) diff --git a/dpnp/dpnp_iface_arraycreation.py b/dpnp/dpnp_iface_arraycreation.py index 2a8d80fc389..0ed1187cb1d 100644 --- a/dpnp/dpnp_iface_arraycreation.py +++ b/dpnp/dpnp_iface_arraycreation.py @@ -50,6 +50,12 @@ from dpnp.dpnp_algo import * from dpnp.dpnp_utils import * +from .dpnp_algo.dpnp_arraycreation import ( + dpnp_geomspace, + dpnp_linspace, + dpnp_logspace, +) + __all__ = [ "arange", "array", @@ -1019,15 +1025,28 @@ def full_like( return numpy.full_like(x1, fill_value, dtype, order, subok, shape) -def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0): +def geomspace( + start, + stop, + /, + num, + *, + dtype=None, + device=None, + usm_type=None, + sycl_queue=None, + endpoint=True, + axis=0, +): """ Return numbers spaced evenly on a log scale (a geometric progression). For full documentation refer to :obj:`numpy.geomspace`. - Limitations - ----------- - Parameter `axis` is supported only with default value ``0``. + Returns + ------- + out : dpnp.ndarray + num samples, equally spaced on a log scale. See Also -------- @@ -1041,24 +1060,38 @@ def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0): Examples -------- >>> import dpnp as np - >>> x = np.geomspace(1, 1000, num=4) - >>> [i for i in x] - [1.0, 10.0, 100.0, 1000.0] - >>> x2 = np.geomspace(1, 1000, num=4, endpoint=False) - >>> [i for i in x2] - [1.0, 5.62341325, 31.6227766, 177.827941] - - """ - - if not use_origin_backend(): - if axis != 0: - pass - else: - return dpnp_geomspace( - start, stop, num, endpoint, dtype, axis - ).get_pyobj() - - return call_origin(numpy.geomspace, start, stop, num, endpoint, dtype, axis) + >>> np.geomspace(1, 1000, num=4) + array([ 1., 10., 100., 1000.]) + >>> np.geomspace(1, 1000, num=3, endpoint=False) + array([ 1., 10., 100.]) + >>> np.geomspace(1, 1000, num=4, endpoint=False) + array([ 1. , 5.62341325, 31.6227766 , 177.827941 ]) + >>> np.geomspace(1, 256, num=9) + array([ 1., 2., 4., 8., 16., 32., 64., 128., 256.]) + + >>> np.geomspace(1, 256, num=9, dtype=int) + array([ 1, 2, 4, 7, 16, 32, 63, 127, 256]) + >>> np.around(np.geomspace(1, 256, num=9)).astype(int) + array([ 1, 2, 4, 8, 16, 32, 64, 128, 256]) + + >>> np.geomspace(1000, 1, num=4) + array([1000., 100., 10., 1.]) + >>> np.geomspace(-1000, -1, num=4) + array([-1000., -100., -10., -1.]) + + """ + + return dpnp_geomspace( + start, + stop, + num, + dtype=dtype, + device=device, + usm_type=usm_type, + sycl_queue=sycl_queue, + endpoint=endpoint, + axis=axis, + ) def identity( @@ -1133,11 +1166,15 @@ def linspace( For full documentation refer to :obj:`numpy.linspace`. - Limitations - ----------- - Parameter `axis` is supported only with default value ``0``. - Parameter `retstep` is supported only with default value ``False``. - Otherwise the function will be executed sequentially on CPU. + Returns + ------- + out : dpnp.ndarray + There are num equally spaced samples in the closed interval + [`start`, `stop`] or the half-open interval [`start`, `stop`) + (depending on whether `endpoint` is ``True`` or ``False``). + step : float, optional + Only returned if `retstep` is ``True``. + Size of spacing between samples. See Also -------- @@ -1151,36 +1188,28 @@ def linspace( Examples -------- >>> import dpnp as np - >>> x = np.linspace(2.0, 3.0, num=5) - >>> [i for i in x] - [2.0, 2.25, 2.5, 2.75, 3.0] - >>> x2 = np.linspace(2.0, 3.0, num=5, endpoint=False) - >>> [i for i in x2] - [2.0, 2.2, 2.4, 2.6, 2.8] - >>> x3, step = np.linspace(2.0, 3.0, num=5, retstep=True) - >>> [i for i in x3], step - ([2.0, 2.25, 2.5, 2.75, 3.0], 0.25) + >>> np.linspace(2.0, 3.0, num=5) + array([2. , 2.25, 2.5 , 2.75, 3. ]) - """ + >>> np.linspace(2.0, 3.0, num=5, endpoint=False) + array([2. , 2.2, 2.4, 2.6, 2.8]) - if retstep is not False: - pass - elif axis != 0: - pass - else: - return dpnp_linspace( - start, - stop, - num, - dtype=dtype, - device=device, - usm_type=usm_type, - sycl_queue=sycl_queue, - endpoint=endpoint, - ) + >>> np.linspace(2.0, 3.0, num=5, retstep=True) + (array([2. , 2.25, 2.5 , 2.75, 3. ]), array(0.25)) - return call_origin( - numpy.linspace, start, stop, num, endpoint, retstep, dtype, axis + """ + + return dpnp_linspace( + start, + stop, + num, + dtype=dtype, + device=device, + usm_type=usm_type, + sycl_queue=sycl_queue, + endpoint=endpoint, + retstep=retstep, + axis=axis, ) @@ -1210,15 +1239,29 @@ def loadtxt(fname, **kwargs): return call_origin(numpy.loadtxt, fname, **kwargs) -def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0): +def logspace( + start, + stop, + /, + num=50, + *, + device=None, + usm_type=None, + sycl_queue=None, + endpoint=True, + base=10.0, + dtype=None, + axis=0, +): """ Return numbers spaced evenly on a log scale. For full documentation refer to :obj:`numpy.logspace`. - Limitations - ----------- - Parameter `axis` is supported only with default value ``0``. + Returns + ------- + out: dpnp.ndarray + num samples, equally spaced on a log scale. See Also -------- @@ -1234,28 +1277,32 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0): Examples -------- >>> import dpnp as np - >>> x = np.logspace(2.0, 3.0, num=4) - >>> [i for i in x] - [100.0, 215.443469, 464.15888336, 1000.0] - >>> x2 = np.logspace(2.0, 3.0, num=4, endpoint=False) - >>> [i for i in x2] - [100.0, 177.827941, 316.22776602, 562.34132519] - >>> x3 = np.logspace(2.0, 3.0, num=4, base=2.0) - >>> [i for i in x3] - [4.0, 5.0396842, 6.34960421, 8.0] + >>> np.logspace(2.0, 3.0, num=4) + array([ 100. , 215.443469 , 464.15888336, 1000. ]) - """ + >>> np.logspace(2.0, 3.0, num=4, endpoint=False) + array([100. , 177.827941 , 316.22776602, 562.34132519]) - if not use_origin_backend(): - if axis != 0: - checker_throw_value_error("linspace", "axis", axis, 0) + >>> np.logspace(2.0, 3.0, num=4, base=2.0) + array([4. , 5.0396842 , 6.34960421, 8. ]) - return dpnp_logspace( - start, stop, num, endpoint, base, dtype, axis - ).get_pyobj() + >>> np.logspace(2.0, 3.0, num=4, base=[2.0, 3.0], axis=-1) + array([[ 4. , 5.0396842 , 6.34960421, 8. ], + [ 9. , 12.98024613, 18.72075441, 27. ]]) - return call_origin( - numpy.logspace, start, stop, num, endpoint, base, dtype, axis + """ + + return dpnp_logspace( + start, + stop, + num=num, + device=device, + usm_type=usm_type, + sycl_queue=sycl_queue, + endpoint=endpoint, + base=base, + dtype=dtype, + axis=axis, ) diff --git a/tests/skipped_tests.tbl b/tests/skipped_tests.tbl index 48490d92f38..e61f7497d97 100644 --- a/tests/skipped_tests.tbl +++ b/tests/skipped_tests.tbl @@ -200,10 +200,6 @@ tests/third_party/cupy/creation_tests/test_ranges.py::TestMgrid::test_mgrid3 tests/third_party/cupy/creation_tests/test_ranges.py::TestOgrid::test_ogrid3 tests/third_party/cupy/creation_tests/test_ranges.py::TestOgrid::test_ogrid4 tests/third_party/cupy/creation_tests/test_ranges.py::TestOgrid::test_ogrid5 -tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_array_start_stop_axis1 -tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_one_num_no_endopoint_with_retstep -tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_with_retstep -tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_zero_num_no_endopoint_with_retstep tests/third_party/cupy/indexing_tests/test_generate.py::TestAxisConcatenator::test_AxisConcatenator_init1 tests/third_party/cupy/indexing_tests/test_generate.py::TestAxisConcatenator::test_len tests/third_party/cupy/indexing_tests/test_generate.py::TestC_::test_c_1 diff --git a/tests/skipped_tests_gpu.tbl b/tests/skipped_tests_gpu.tbl index 446ca789f25..a8d198b7733 100644 --- a/tests/skipped_tests_gpu.tbl +++ b/tests/skipped_tests_gpu.tbl @@ -278,11 +278,6 @@ tests/third_party/cupy/creation_tests/test_ranges.py::TestOgrid::test_ogrid4 tests/third_party/cupy/creation_tests/test_ranges.py::TestOgrid::test_ogrid5 tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_arange_negative_size tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_arange_no_dtype_int -tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_array_start_stop_axis1 -tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_one_num_no_endopoint_with_retstep -tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_with_retstep -tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_zero_num_no_endopoint_with_retstep -tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_logspace_zero_num tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_1_{axes=None, norm=None, s=(1, None), shape=(3, 4)}::test_fft2 tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_fft2 diff --git a/tests/test_arraycreation.py b/tests/test_arraycreation.py index 9c7fd6bf060..7c674d265da 100644 --- a/tests/test_arraycreation.py +++ b/tests/test_arraycreation.py @@ -201,27 +201,6 @@ def test_fromstring(dtype): assert_array_equal(func(dpnp), func(numpy)) -@pytest.mark.usefixtures("allow_fall_back_on_numpy") -@pytest.mark.parametrize("dtype", get_all_dtypes()) -@pytest.mark.parametrize("num", [2, 4, 8, 3, 9, 27]) -@pytest.mark.parametrize("endpoint", [True, False]) -def test_geomspace(dtype, num, endpoint): - start = 2 - stop = 256 - - func = lambda xp: xp.geomspace(start, stop, num, endpoint, dtype) - - np_res = func(numpy) - dpnp_res = func(dpnp) - - # Note that the above may not produce exact integers: - # (c) https://numpy.org/doc/stable/reference/generated/numpy.geomspace.html - if dtype in [numpy.int64, numpy.int32]: - assert_allclose(dpnp_res, np_res, atol=1) - else: - assert_allclose(dpnp_res, np_res) - - @pytest.mark.parametrize("n", [0, 1, 4], ids=["0", "1", "4"]) @pytest.mark.parametrize("dtype", get_all_dtypes()) def test_identity(n, dtype): @@ -629,25 +608,35 @@ def test_dpctl_tensor_input(func, args): ) @pytest.mark.parametrize( "num", - [5, numpy.array(10), dpnp.array(17), dpt.asarray(100)], - ids=["5", "numpy.array(10)", "dpnp.array(17)", "dpt.asarray(100)"], + [1, 5, numpy.array(10), dpnp.array(17), dpt.asarray(100)], + ids=["1", "5", "numpy.array(10)", "dpnp.array(17)", "dpt.asarray(100)"], ) @pytest.mark.parametrize( "dtype", get_all_dtypes(no_bool=True, no_float16=False) ) -def test_linspace(start, stop, num, dtype): - func = lambda xp: xp.linspace(start, stop, num, dtype=dtype) +@pytest.mark.parametrize("retstep", [True, False], ids=["True", "False"]) +def test_linspace(start, stop, num, dtype, retstep): + res_np = numpy.linspace(start, stop, num, dtype=dtype, retstep=retstep) + res_dp = dpnp.linspace(start, stop, num, dtype=dtype, retstep=retstep) + + if retstep: + [res_np, step_np] = res_np + [res_dp, step_dp] = res_dp + assert_allclose(step_np, step_dp) if numpy.issubdtype(dtype, dpnp.integer): - assert_allclose(func(numpy), func(dpnp), rtol=1) + assert_allclose(res_np, res_dp, rtol=1) else: if dtype is None and not has_support_aspect64(): dtype = dpnp.float32 - assert_allclose( - func(numpy), func(dpnp), rtol=1e-06, atol=numpy.finfo(dtype).eps - ) + assert_allclose(res_np, res_dp, rtol=1e-06, atol=dpnp.finfo(dtype).eps) +@pytest.mark.parametrize( + "func", + ["geomspace", "linspace", "logspace"], + ids=["geomspace", "linspace", "logspace"], +) @pytest.mark.parametrize( "start_dtype", [numpy.float64, numpy.float32, numpy.int64, numpy.int32], @@ -658,10 +647,10 @@ def test_linspace(start, stop, num, dtype): [numpy.float64, numpy.float32, numpy.int64, numpy.int32], ids=["float64", "float32", "int64", "int32"], ) -def test_linspace_dtype(start_dtype, stop_dtype): +def test_space_numpy_dtype(func, start_dtype, stop_dtype): start = numpy.array([1, 2, 3], dtype=start_dtype) stop = numpy.array([11, 7, -2], dtype=stop_dtype) - dpnp.linspace(start, stop, 10) + getattr(dpnp, func)(start, stop, 10) @pytest.mark.parametrize( @@ -691,7 +680,28 @@ def test_linspace_arrays(start, stop): def test_linspace_complex(): func = lambda xp: xp.linspace(0, 3 + 2j, num=1000) - assert_allclose(func(numpy), func(dpnp)) + assert_allclose(func(dpnp), func(numpy)) + + +@pytest.mark.parametrize("axis", [0, 1]) +def test_linspace_axis(axis): + func = lambda xp: xp.linspace([2, 3], [20, 15], num=10, axis=axis) + assert_allclose(func(dpnp), func(numpy)) + + +def test_linspace_step_nan(): + func = lambda xp: xp.linspace(1, 2, num=0, endpoint=False) + assert_allclose(func(dpnp), func(numpy)) + + +@pytest.mark.parametrize("start", [1, [1, 1]]) +@pytest.mark.parametrize("stop", [10, [10 + 10]]) +def test_linspace_retstep(start, stop): + func = lambda xp: xp.linspace(start, stop, num=10, retstep=True) + np_res = func(numpy) + dpnp_res = func(dpnp) + assert_allclose(dpnp_res[0], np_res[0]) + assert_allclose(dpnp_res[1], np_res[1]) @pytest.mark.parametrize( @@ -716,3 +726,100 @@ def test_set_shape(shape): da.shape = shape assert_array_equal(na, da) + + +def test_geomspace_zero_error(): + with pytest.raises(ValueError): + dpnp.geomspace(0, 5, 3) + dpnp.geomspace(2, 0, 3) + dpnp.geomspace(0, 0, 3) + + +def test_space_num_error(): + with pytest.raises(ValueError): + dpnp.linspace(2, 5, -3) + dpnp.geomspace(2, 5, -3) + dpnp.logspace(2, 5, -3) + dpnp.linspace([2, 3], 5, -3) + dpnp.geomspace([2, 3], 5, -3) + dpnp.logspace([2, 3], 5, -3) + + +@pytest.mark.parametrize("sign", [-1, 1]) +@pytest.mark.parametrize("dtype", get_all_dtypes()) +@pytest.mark.parametrize("num", [2, 4, 8, 3, 9, 27]) +@pytest.mark.parametrize("endpoint", [True, False]) +def test_geomspace(sign, dtype, num, endpoint): + start = 2 * sign + stop = 256 * sign + + func = lambda xp: xp.geomspace( + start, stop, num, endpoint=endpoint, dtype=dtype + ) + + np_res = func(numpy) + dpnp_res = func(dpnp) + + if dtype in [numpy.int64, numpy.int32]: + assert_allclose(dpnp_res, np_res, rtol=1) + else: + assert_allclose(dpnp_res, np_res, rtol=1e-04) + + +@pytest.mark.usefixtures("allow_fall_back_on_numpy") +@pytest.mark.parametrize("start", [1j, 1 + 1j]) +@pytest.mark.parametrize("stop", [10j, 10 + 10j]) +# dpnp.sign raise numpy fall back for complex dtype +def test_geomspace_complex(start, stop): + func = lambda xp: xp.geomspace(start, stop, num=10) + np_res = func(numpy) + dpnp_res = func(dpnp) + assert_allclose(dpnp_res, np_res, rtol=1e-04) + + +@pytest.mark.parametrize("axis", [0, 1]) +def test_geomspace_axis(axis): + func = lambda xp: xp.geomspace([2, 3], [20, 15], num=10, axis=axis) + np_res = func(numpy) + dpnp_res = func(dpnp) + assert_allclose(dpnp_res, np_res, rtol=1e-04) + + +def test_geomspace_num0(): + func = lambda xp: xp.geomspace(1, 10, num=0, endpoint=False) + np_res = func(numpy) + dpnp_res = func(dpnp) + assert_allclose(dpnp_res, np_res, rtol=1e-04) + + +@pytest.mark.parametrize("dtype", get_all_dtypes()) +@pytest.mark.parametrize("num", [2, 4, 8, 3, 9, 27]) +@pytest.mark.parametrize("endpoint", [True, False]) +def test_logspace(dtype, num, endpoint): + start = 2 + stop = 5 + base = 2 + + func = lambda xp: xp.logspace( + start, stop, num, endpoint=endpoint, dtype=dtype, base=base + ) + + np_res = func(numpy) + dpnp_res = func(dpnp) + + if dtype in [numpy.int64, numpy.int32]: + assert_allclose(dpnp_res, np_res, rtol=1) + else: + assert_allclose(dpnp_res, np_res, rtol=1e-04) + + +@pytest.mark.parametrize("axis", [0, 1]) +def test_logspace_axis(axis): + if numpy.lib.NumpyVersion(numpy.__version__) < "1.25.0": + pytest.skip( + "numpy.logspace supports a non-scalar base argument since 1.25.0" + ) + func = lambda xp: xp.logspace( + [2, 3], [20, 15], num=2, base=[[1, 3], [5, 7]], axis=axis + ) + assert_allclose(func(dpnp), func(numpy)) diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index 3c131a6462f..2a4a814b6f7 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -85,8 +85,10 @@ def vvsort(val, vec, size, xp): pytest.param("arange", [-25.7], {"stop": 10**8, "step": 15}), pytest.param("full", [(2, 2)], {"fill_value": 5}), pytest.param("eye", [4, 2], {}), + pytest.param("geomspace", [1, 4, 8], {}), pytest.param("identity", [4], {}), pytest.param("linspace", [0, 4, 8], {}), + pytest.param("logspace", [0, 4, 8], {}), pytest.param("ones", [(2, 2)], {}), pytest.param("tri", [3, 5, 2], {}), pytest.param("zeros", [(2, 2)], {}), @@ -140,12 +142,16 @@ def test_empty_like(device_x, device_y): "func, args, kwargs", [ pytest.param("full_like", ["x0"], {"fill_value": 5}), + pytest.param("geomspace", ["x0[0:3]", "8", "4"], {}), + pytest.param("geomspace", ["1", "x0[2:4]", "4"], {}), + pytest.param("linspace", ["x0[0:2]", "8", "4"], {}), + pytest.param("linspace", ["0", "x0[2:4]", "4"], {}), + pytest.param("logspace", ["x0[0:2]", "8", "4"], {}), + pytest.param("logspace", ["0", "x0[2:4]", "4"], {}), pytest.param("ones_like", ["x0"], {}), - pytest.param("zeros_like", ["x0"], {}), pytest.param("tril", ["x0.reshape((2,2))"], {}), pytest.param("triu", ["x0.reshape((2,2))"], {}), - pytest.param("linspace", ["x0", "4", "4"], {}), - pytest.param("linspace", ["1", "x0", "4"], {}), + pytest.param("zeros_like", ["x0"], {}), ], ) @pytest.mark.parametrize( @@ -162,7 +168,27 @@ def test_array_creation_follow_device(func, args, kwargs, device): dpnp_args = [eval(val, {"x0": x}) for val in args] y = getattr(dpnp, func)(*dpnp_args, **kwargs) - assert_allclose(y_orig, y) + assert_allclose(y_orig, y, rtol=1e-04) + assert_sycl_queue_equal(y.sycl_queue, x.sycl_queue) + + +@pytest.mark.skipif( + numpy.lib.NumpyVersion(numpy.__version__) < "1.25.0", + reason="numpy.logspace supports a non-scalar base argument since 1.25.0", +) +@pytest.mark.parametrize( + "device", + valid_devices, + ids=[device.filter_string for device in valid_devices], +) +def test_array_creation_follow_device_logspace_base(device): + x_orig = numpy.array([1, 2, 3, 4]) + y_orig = numpy.logspace(0, 8, 4, base=x_orig[1:3]) + + x = dpnp.array([1, 2, 3, 4], device=device) + y = dpnp.logspace(0, 8, 4, base=x[1:3]) + + assert_allclose(y_orig, y, rtol=1e-04) assert_sycl_queue_equal(y.sycl_queue, x.sycl_queue) diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index 3060edd4bea..99a39acae88 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -140,13 +140,17 @@ def test_coerced_usm_types_power(usm_type_x, usm_type_y): @pytest.mark.parametrize( "func, args", [ + pytest.param("empty_like", ["x0"]), pytest.param("full", ["10", "x0[3]"]), pytest.param("full_like", ["x0", "4"]), - pytest.param("zeros_like", ["x0"]), - pytest.param("ones_like", ["x0"]), - pytest.param("empty_like", ["x0"]), - pytest.param("linspace", ["x0[0:2]", "4", "4"]), + pytest.param("geomspace", ["x0[0:3]", "8", "4"]), + pytest.param("geomspace", ["1", "x0[3:5]", "4"]), + pytest.param("linspace", ["x0[0:2]", "8", "4"]), pytest.param("linspace", ["0", "x0[3:5]", "4"]), + pytest.param("logspace", ["x0[0:2]", "8", "4"]), + pytest.param("logspace", ["0", "x0[3:5]", "4"]), + pytest.param("ones_like", ["x0"]), + pytest.param("zeros_like", ["x0"]), ], ) @pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types) @@ -168,8 +172,10 @@ def test_array_creation_from_an_array(func, args, usm_type_x, usm_type_y): pytest.param("arange", [-25.7], {"stop": 10**8, "step": 15}), pytest.param("full", [(2, 2)], {"fill_value": 5}), pytest.param("eye", [4, 2], {}), + pytest.param("geomspace", [1, 4, 8], {}), pytest.param("identity", [4], {}), pytest.param("linspace", [0, 4, 8], {}), + pytest.param("logspace", [0, 4, 8], {}), pytest.param("ones", [(2, 2)], {}), pytest.param("tri", [3, 5, 2], {}), pytest.param("zeros", [(2, 2)], {}), @@ -189,6 +195,18 @@ def test_array_creation_from_scratch(func, arg, kwargs, usm_type): assert dpnp_array.usm_type == usm_type +@pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types) +@pytest.mark.parametrize("usm_type_y", list_of_usm_types, ids=list_of_usm_types) +def test_logspace_base(usm_type_x, usm_type_y): + x0 = dp.full(10, 2, usm_type=usm_type_x) + + x = dp.logspace([2, 2], 8, 4, base=x0[3:5]) + y = dp.logspace([2, 2], 8, 4, base=x0[3:5], usm_type=usm_type_y) + + assert x.usm_type == usm_type_x + assert y.usm_type == usm_type_y + + @pytest.mark.parametrize( "func", [ diff --git a/tests/third_party/cupy/creation_tests/test_ranges.py b/tests/third_party/cupy/creation_tests/test_ranges.py index be2f113a318..623adc409b7 100644 --- a/tests/third_party/cupy/creation_tests/test_ranges.py +++ b/tests/third_party/cupy/creation_tests/test_ranges.py @@ -1,3 +1,4 @@ +import functools import math import sys import unittest @@ -10,6 +11,27 @@ from tests.third_party.cupy import testing +def skip_int_equality_before_numpy_1_20(names=("dtype",)): + """Require numpy/numpy#16841 or skip the equality check.""" + + def decorator(wrapped): + if numpy.lib.NumpyVersion(numpy.__version__) >= "1.20.0": + return wrapped + + @functools.wraps(wrapped) + def wrapper(self, *args, **kwargs): + xp = kwargs["xp"] + dtypes = [kwargs[name] for name in names] + ret = wrapped(self, *args, **kwargs) + if any(numpy.issubdtype(dtype, numpy.integer) for dtype in dtypes): + ret = xp.zeros_like(ret) + return ret + + return wrapper + + return decorator + + @testing.gpu class TestRanges(unittest.TestCase): @testing.for_all_dtypes(no_bool=True) @@ -79,6 +101,14 @@ def test_linspace(self, xp, dtype): def test_linspace2(self, xp, dtype): return xp.linspace(10, 0, 5, dtype=dtype) + @testing.for_all_dtypes(no_bool=True) + @testing.numpy_cupy_array_equal() + @skip_int_equality_before_numpy_1_20() + def test_linspace3(self, xp, dtype): + if xp.dtype(dtype).kind == "u": + pytest.skip() + return xp.linspace(-10, 8, 9, dtype=dtype) + @testing.for_all_dtypes(no_bool=True) @testing.numpy_cupy_array_equal() def test_linspace_zero_num(self, xp, dtype): @@ -192,7 +222,7 @@ def test_linspace_mixed_start_stop2(self, xp, dtype_range, dtype_out): @testing.for_all_dtypes_combination( names=("dtype_range", "dtype_out"), no_bool=True, no_complex=True ) - @testing.numpy_cupy_array_equal() + @testing.numpy_cupy_allclose(rtol=1e-04) def test_linspace_array_start_stop_axis1(self, xp, dtype_range, dtype_out): start = xp.array([0, 120], dtype=dtype_range) stop = xp.array([100, 0], dtype=dtype_range) @@ -214,13 +244,11 @@ def test_linspace_start_stop_list(self, xp, dtype): stop = [100, 16] return xp.linspace(start, stop, num=50, dtype=dtype) - @pytest.mark.usefixtures("allow_fall_back_on_numpy") @testing.for_all_dtypes(no_bool=True) @testing.numpy_cupy_allclose() def test_logspace(self, xp, dtype): return xp.logspace(0, 2, 5, dtype=dtype) - @pytest.mark.usefixtures("allow_fall_back_on_numpy") @testing.for_all_dtypes(no_bool=True) @testing.numpy_cupy_allclose() def test_logspace2(self, xp, dtype): @@ -231,29 +259,24 @@ def test_logspace2(self, xp, dtype): def test_logspace_zero_num(self, xp, dtype): return xp.logspace(0, 2, 0, dtype=dtype) - @pytest.mark.usefixtures("allow_fall_back_on_numpy") @testing.for_all_dtypes(no_bool=True) @testing.numpy_cupy_allclose() def test_logspace_one_num(self, xp, dtype): return xp.logspace(0, 2, 1, dtype=dtype) - @pytest.mark.usefixtures("allow_fall_back_on_numpy") @testing.for_all_dtypes(no_bool=True) @testing.numpy_cupy_allclose() def test_logspace_no_endpoint(self, xp, dtype): return xp.logspace(0, 2, 5, dtype=dtype, endpoint=False) - @pytest.mark.usefixtures("allow_fall_back_on_numpy") @testing.numpy_cupy_allclose(rtol=1e-4, type_check=has_support_aspect64()) def test_logspace_no_dtype_int(self, xp): return xp.logspace(0, 2) - @pytest.mark.usefixtures("allow_fall_back_on_numpy") @testing.numpy_cupy_allclose(rtol=1e-4, type_check=has_support_aspect64()) def test_logspace_no_dtype_float(self, xp): return xp.logspace(0.0, 2.0) - @pytest.mark.usefixtures("allow_fall_back_on_numpy") @testing.numpy_cupy_allclose() def test_logspace_float_args_with_int_dtype(self, xp): return xp.logspace(0.1, 2.1, 11, dtype=int) @@ -263,12 +286,22 @@ def test_logspace_neg_num(self): with pytest.raises(ValueError): xp.logspace(0, 10, -1) - @pytest.mark.usefixtures("allow_fall_back_on_numpy") @testing.for_all_dtypes(no_bool=True) @testing.numpy_cupy_allclose(rtol=1e-04) def test_logspace_base(self, xp, dtype): return xp.logspace(0, 2, 5, base=2.0, dtype=dtype) + # See #7946 and https://github.com/numpy/numpy/issues/24957 + @testing.with_requires("numpy>=1.16, !=1.25.*, !=1.26.*") + @testing.for_all_dtypes_combination( + names=("dtype_range", "dtype_out"), no_bool=True, no_complex=True + ) + @testing.numpy_cupy_allclose(rtol=1e-6, contiguous_check=False) + def test_logspace_array_start_stop_axis1(self, xp, dtype_range, dtype_out): + start = xp.array([0, 2], dtype=dtype_range) + stop = xp.array([2, 0], dtype=dtype_range) + return xp.logspace(start, stop, num=5, dtype=dtype_out, axis=1) + @testing.parameterize( *testing.product(