diff --git a/dpnp/fft/dpnp_iface_fft.py b/dpnp/fft/dpnp_iface_fft.py index 0fbd15d5cf0..4bf26b1eba7 100644 --- a/dpnp/fft/dpnp_iface_fft.py +++ b/dpnp/fft/dpnp_iface_fft.py @@ -366,45 +366,62 @@ def fftshift(x, axes=None): """ Shift the zero-frequency component to the center of the spectrum. + This function swaps half-spaces for all axes listed (defaults to all). + Note that ``out[0]`` is the Nyquist component only if ``len(x)`` is even. + For full documentation refer to :obj:`numpy.fft.fftshift`. - Limitations - ----------- - Parameter `x` is supported either as :class:`dpnp.ndarray`. - Parameter `axes` is unsupported. - Only `dpnp.float64`, `dpnp.float32`, `dpnp.int64`, `dpnp.int32`, - `dpnp.complex128` data types are supported. - Otherwise the function will be executed sequentially on CPU. + Parameters + ---------- + x : {dpnp.ndarray, usm_ndarray} + Input array. + axes : {None, int, list or tuple of ints}, optional + Axes over which to shift. + Default is ``None``, which shifts all axes. - """ + Returns + ------- + out : dpnp.ndarray + The shifted array. - x_desc = dpnp.get_dpnp_descriptor(x, copy_when_nondefault_queue=False) - # TODO: enable implementation - # pylint: disable=condition-evals-to-constant - if x_desc and 0: - norm_ = Norm.backward + See Also + -------- + :obj:`dpnp.fft.ifftshift` : The inverse of :obj:`dpnp.fft.fftshift`. - if axes is None: - axis_param = -1 # the most right dimension (default value) - else: - axis_param = axes + Examples + -------- + >>> import dpnp as np + >>> freqs = np.fft.fftfreq(10, 0.1) + >>> freqs + array([ 0., 1., 2., 3., 4., -5., -4., -3., -2., -1.]) + >>> np.fft.fftshift(freqs) + array([-5., -4., -3., -2., -1., 0., 1., 2., 3., 4.]) + + Shift the zero-frequency component only along the second axis: + + >>> freqs = np.fft.fftfreq(9, d=1./9).reshape(3, 3) + >>> freqs + array([[ 0., 1., 2.], + [ 3., 4., -4.], + [-3., -2., -1.]]) + >>> np.fft.fftshift(freqs, axes=(1,)) + array([[ 2., 0., 1.], + [-4., 3., 4.], + [-1., -3., -2.]]) - if x_desc.size < 1: - pass # let fallback to handle exception - else: - input_boundarie = x_desc.shape[axis_param] - output_boundarie = input_boundarie + """ - return dpnp_fft_deprecated( - x_desc, - input_boundarie, - output_boundarie, - axis_param, - False, - norm_.value, - ).get_pyobj() + dpnp.check_supported_arrays_type(x) + if axes is None: + axes = tuple(range(x.ndim)) + shift = [dim // 2 for dim in x.shape] + elif isinstance(axes, int): + shift = x.shape[axes] // 2 + else: + x_shape = x.shape + shift = [x_shape[ax] // 2 for ax in axes] - return call_origin(numpy.fft.fftshift, x, axes) + return dpnp.roll(x, shift, axes) def hfft(x, n=None, axis=-1, norm=None): @@ -620,48 +637,55 @@ def ifftshift(x, axes=None): """ Inverse shift the zero-frequency component to the center of the spectrum. - For full documentation refer to :obj:`numpy.fft.ifftshift`. + Although identical for even-length `x`, the functions differ by one sample + for odd-length `x`. - Limitations - ----------- - Parameter `x` is supported either as :class:`dpnp.ndarray`. - Parameter `axes` is unsupported. - Only `dpnp.float64`, `dpnp.float32`, `dpnp.int64`, `dpnp.int32`, - `dpnp.complex128` data types are supported. - Otherwise the function will be executed sequentially on CPU. - - """ + For full documentation refer to :obj:`numpy.fft.ifftshift`. - x_desc = dpnp.get_dpnp_descriptor(x, copy_when_nondefault_queue=False) - # TODO: enable implementation - # pylint: disable=condition-evals-to-constant - if x_desc and 0: - norm_ = Norm.backward + Parameters + ---------- + x : {dpnp.ndarray, usm_ndarray} + Input array. + axes : {None, int, list or tuple of ints}, optional + Axes over which to calculate. + Defaults to ``None``, which shifts all axes. - if axes is None: - axis_param = -1 # the most right dimension (default value) - else: - axis_param = axes + Returns + ------- + out : dpnp.ndarray + The shifted array. - input_boundarie = x_desc.shape[axis_param] + See Also + -------- + :obj:`dpnp.fft.fftshift` : Shift zero-frequency component to the center + of the spectrum. - if x_desc.size < 1: - pass # let fallback to handle exception - elif input_boundarie < 1: - pass # let fallback to handle exception - else: - output_boundarie = input_boundarie + Examples + -------- + >>> import dpnp as np + >>> freqs = np.fft.fftfreq(9, d=1./9).reshape(3, 3) + >>> freqs + array([[ 0., 1., 2.], + [ 3., 4., -4.], + [-3., -2., -1.]]) + >>> np.fft.ifftshift(np.fft.fftshift(freqs)) + array([[ 0., 1., 2.], + [ 3., 4., -4.], + [-3., -2., -1.]]) - return dpnp_fft_deprecated( - x_desc, - input_boundarie, - output_boundarie, - axis_param, - True, - norm_.value, - ).get_pyobj() + """ - return call_origin(numpy.fft.ifftshift, x, axes) + dpnp.check_supported_arrays_type(x) + if axes is None: + axes = tuple(range(x.ndim)) + shift = [-(dim // 2) for dim in x.shape] + elif isinstance(axes, int): + shift = -(x.shape[axes] // 2) + else: + x_shape = x.shape + shift = [-(x_shape[ax] // 2) for ax in axes] + + return dpnp.roll(x, shift, axes) def ihfft(x, n=None, axis=-1, norm=None): diff --git a/tests/test_fft.py b/tests/test_fft.py index f6818a53214..4091eb6a8f8 100644 --- a/tests/test_fft.py +++ b/tests/test_fft.py @@ -372,3 +372,14 @@ def test_error(self, func): # d should be an scalar assert_raises(ValueError, getattr(dpnp.fft, func), 10, (2,)) + + +class TestFftshift: + @pytest.mark.parametrize("func", ["fftshift", "ifftshift"]) + @pytest.mark.parametrize("axes", [None, 1, (0, 1)]) + def test_fftshift(self, func, axes): + x = dpnp.arange(12).reshape(3, 4) + x_np = x.asnumpy() + expected = getattr(dpnp.fft, func)(x, axes=axes) + result = getattr(numpy.fft, func)(x_np, axes=axes) + assert_dtype_allclose(expected, result) diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index 73b60d918f3..4c6a001273a 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -1276,6 +1276,27 @@ def test_fftfreq(func, device): assert result.sycl_device == device +@pytest.mark.parametrize("func", ["fftshift", "ifftshift"]) +@pytest.mark.parametrize( + "device", + valid_devices, + ids=[device.filter_string for device in valid_devices], +) +def test_fftshift(func, device): + dpnp_data = dpnp.fft.fftfreq(10, 0.5, device=device) + data = dpnp_data.asnumpy() + + expected = getattr(numpy.fft, func)(data) + result = getattr(dpnp.fft, func)(dpnp_data) + + assert_dtype_allclose(result, expected) + + expected_queue = dpnp_data.get_array().sycl_queue + result_queue = result.get_array().sycl_queue + + assert_sycl_queue_equal(result_queue, expected_queue) + + @pytest.mark.parametrize( "data, is_empty", [ diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index a618b99aac3..6cc8d5edd39 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -933,6 +933,17 @@ def test_eigenvalue(func, shape, usm_type): assert a.usm_type == dp_val.usm_type +@pytest.mark.parametrize("func", ["fft", "ifft"]) +@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types) +def test_fft(func, usm_type): + + dpnp_data = dp.arange(100, usm_type=usm_type, dtype=dp.complex64) + result = getattr(dp.fft, func)(dpnp_data) + + assert dpnp_data.usm_type == usm_type + assert result.usm_type == usm_type + + @pytest.mark.parametrize("func", ["fftfreq", "rfftfreq"]) @pytest.mark.parametrize("usm_type", list_of_usm_types + [None]) def test_fftfreq(func, usm_type): @@ -947,10 +958,10 @@ def test_fftfreq(func, usm_type): assert result.usm_type == usm_type -@pytest.mark.parametrize("func", ["fft", "ifft"]) +@pytest.mark.parametrize("func", ["fftshift", "ifftshift"]) @pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types) -def test_fft(func, usm_type): - dpnp_data = dp.arange(100, usm_type=usm_type, dtype=dp.complex64) +def test_fftshift(func, usm_type): + dpnp_data = dp.fft.fftfreq(10, 0.5, usm_type=usm_type) result = getattr(dp.fft, func)(dpnp_data) assert dpnp_data.usm_type == usm_type diff --git a/tests/third_party/cupy/fft_tests/test_fft.py b/tests/third_party/cupy/fft_tests/test_fft.py index 7b354cdab27..401a7bf26e9 100644 --- a/tests/third_party/cupy/fft_tests/test_fft.py +++ b/tests/third_party/cupy/fft_tests/test_fft.py @@ -380,7 +380,6 @@ def test_rfftfreq(self, xp): {"shape": (10, 10), "axes": 0}, {"shape": (10, 10), "axes": (0, 1)}, ) -@pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestFftshift: @testing.for_all_dtypes() @testing.numpy_cupy_allclose(