Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement dpnp.fft.fftshift and dpnp.fft.ifftshift #1900

Merged
merged 8 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 90 additions & 66 deletions dpnp/fft/dpnp_iface_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,45 +366,62 @@ def fftshift(x, axes=None):
"""
Shift the zero-frequency component to the center of the spectrum.

This function swaps half-spaces for all axes listed (defaults to all).
Note that ``out[0]`` is the Nyquist component only if ``len(x)`` is even.

For full documentation refer to :obj:`numpy.fft.fftshift`.

Limitations
-----------
Parameter `x` is supported either as :class:`dpnp.ndarray`.
Parameter `axes` is unsupported.
Only `dpnp.float64`, `dpnp.float32`, `dpnp.int64`, `dpnp.int32`,
`dpnp.complex128` data types are supported.
Otherwise the function will be executed sequentially on CPU.
Parameters
----------
x : {dpnp.ndarray, usm_ndarray}
Input array.
axes : {None, int, list or tuple of ints}, optional
Axes over which to shift.
Default is ``None``, which shifts all axes.

"""
Returns
-------
out : dpnp.ndarray
The shifted array.

x_desc = dpnp.get_dpnp_descriptor(x, copy_when_nondefault_queue=False)
# TODO: enable implementation
# pylint: disable=condition-evals-to-constant
if x_desc and 0:
norm_ = Norm.backward
See Also
--------
:obj:`dpnp.fft.ifftshift` : The inverse of :obj:`dpnp.fft.fftshift`.

if axes is None:
axis_param = -1 # the most right dimension (default value)
else:
axis_param = axes
Examples
--------
>>> import dpnp as np
>>> freqs = np.fft.fftfreq(10, 0.1)
>>> freqs
array([ 0., 1., 2., 3., 4., -5., -4., -3., -2., -1.])
>>> np.fft.fftshift(freqs)
array([-5., -4., -3., -2., -1., 0., 1., 2., 3., 4.])

Shift the zero-frequency component only along the second axis:

>>> freqs = np.fft.fftfreq(9, d=1./9).reshape(3, 3)
>>> freqs
array([[ 0., 1., 2.],
[ 3., 4., -4.],
[-3., -2., -1.]])
>>> np.fft.fftshift(freqs, axes=(1,))
array([[ 2., 0., 1.],
[-4., 3., 4.],
[-1., -3., -2.]])

if x_desc.size < 1:
pass # let fallback to handle exception
else:
input_boundarie = x_desc.shape[axis_param]
output_boundarie = input_boundarie
"""

return dpnp_fft_deprecated(
x_desc,
input_boundarie,
output_boundarie,
axis_param,
False,
norm_.value,
).get_pyobj()
dpnp.check_supported_arrays_type(x)
if axes is None:
axes = tuple(range(x.ndim))
shift = [dim // 2 for dim in x.shape]
elif isinstance(axes, int):
shift = x.shape[axes] // 2
else:
x_shape = x.shape
shift = [x_shape[ax] // 2 for ax in axes]

return call_origin(numpy.fft.fftshift, x, axes)
return dpnp.roll(x, shift, axes)


def hfft(x, n=None, axis=-1, norm=None):
Expand Down Expand Up @@ -620,48 +637,55 @@ def ifftshift(x, axes=None):
"""
Inverse shift the zero-frequency component to the center of the spectrum.

For full documentation refer to :obj:`numpy.fft.ifftshift`.
Although identical for even-length `x`, the functions differ by one sample
for odd-length `x`.

Limitations
-----------
Parameter `x` is supported either as :class:`dpnp.ndarray`.
Parameter `axes` is unsupported.
Only `dpnp.float64`, `dpnp.float32`, `dpnp.int64`, `dpnp.int32`,
`dpnp.complex128` data types are supported.
Otherwise the function will be executed sequentially on CPU.

"""
For full documentation refer to :obj:`numpy.fft.ifftshift`.

x_desc = dpnp.get_dpnp_descriptor(x, copy_when_nondefault_queue=False)
# TODO: enable implementation
# pylint: disable=condition-evals-to-constant
if x_desc and 0:
norm_ = Norm.backward
Parameters
----------
x : {dpnp.ndarray, usm_ndarray}
Input array.
axes : {None, int, list or tuple of ints}, optional
Axes over which to calculate.
Defaults to ``None``, which shifts all axes.

if axes is None:
axis_param = -1 # the most right dimension (default value)
else:
axis_param = axes
Returns
-------
out : dpnp.ndarray
The shifted array.

input_boundarie = x_desc.shape[axis_param]
See Also
--------
:obj:`dpnp.fft.fftshift` : Shift zero-frequency component to the center
of the spectrum.

if x_desc.size < 1:
pass # let fallback to handle exception
elif input_boundarie < 1:
pass # let fallback to handle exception
else:
output_boundarie = input_boundarie
Examples
--------
>>> import dpnp as np
>>> freqs = np.fft.fftfreq(9, d=1./9).reshape(3, 3)
>>> freqs
array([[ 0., 1., 2.],
[ 3., 4., -4.],
[-3., -2., -1.]])
>>> np.fft.ifftshift(np.fft.fftshift(freqs))
array([[ 0., 1., 2.],
[ 3., 4., -4.],
[-3., -2., -1.]])

return dpnp_fft_deprecated(
x_desc,
input_boundarie,
output_boundarie,
axis_param,
True,
norm_.value,
).get_pyobj()
"""

return call_origin(numpy.fft.ifftshift, x, axes)
dpnp.check_supported_arrays_type(x)
if axes is None:
axes = tuple(range(x.ndim))
shift = [-(dim // 2) for dim in x.shape]
elif isinstance(axes, int):
shift = -(x.shape[axes] // 2)
else:
x_shape = x.shape
shift = [-(x_shape[ax] // 2) for ax in axes]

return dpnp.roll(x, shift, axes)


def ihfft(x, n=None, axis=-1, norm=None):
Expand Down
11 changes: 11 additions & 0 deletions tests/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,3 +372,14 @@ def test_error(self, func):

# d should be an scalar
assert_raises(ValueError, getattr(dpnp.fft, func), 10, (2,))


class TestFftshift:
@pytest.mark.parametrize("func", ["fftshift", "ifftshift"])
@pytest.mark.parametrize("axes", [None, 1, (0, 1)])
def test_fftshift(self, func, axes):
x = dpnp.arange(12).reshape(3, 4)
x_np = x.asnumpy()
expected = getattr(dpnp.fft, func)(x, axes=axes)
result = getattr(numpy.fft, func)(x_np, axes=axes)
assert_dtype_allclose(expected, result)
21 changes: 21 additions & 0 deletions tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -1276,6 +1276,27 @@ def test_fftfreq(func, device):
assert result.sycl_device == device


@pytest.mark.parametrize("func", ["fftshift", "ifftshift"])
@pytest.mark.parametrize(
"device",
valid_devices,
ids=[device.filter_string for device in valid_devices],
)
def test_fftshift(func, device):
dpnp_data = dpnp.fft.fftfreq(10, 0.5, device=device)
data = dpnp_data.asnumpy()

expected = getattr(numpy.fft, func)(data)
result = getattr(dpnp.fft, func)(dpnp_data)

assert_dtype_allclose(result, expected)

expected_queue = dpnp_data.get_array().sycl_queue
result_queue = result.get_array().sycl_queue

assert_sycl_queue_equal(result_queue, expected_queue)


@pytest.mark.parametrize(
"data, is_empty",
[
Expand Down
17 changes: 14 additions & 3 deletions tests/test_usm_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,17 @@ def test_eigenvalue(func, shape, usm_type):
assert a.usm_type == dp_val.usm_type


@pytest.mark.parametrize("func", ["fft", "ifft"])
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
def test_fft(func, usm_type):

dpnp_data = dp.arange(100, usm_type=usm_type, dtype=dp.complex64)
result = getattr(dp.fft, func)(dpnp_data)

assert dpnp_data.usm_type == usm_type
assert result.usm_type == usm_type


@pytest.mark.parametrize("func", ["fftfreq", "rfftfreq"])
@pytest.mark.parametrize("usm_type", list_of_usm_types + [None])
def test_fftfreq(func, usm_type):
Expand All @@ -947,10 +958,10 @@ def test_fftfreq(func, usm_type):
assert result.usm_type == usm_type


@pytest.mark.parametrize("func", ["fft", "ifft"])
@pytest.mark.parametrize("func", ["fftshift", "ifftshift"])
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
def test_fft(func, usm_type):
dpnp_data = dp.arange(100, usm_type=usm_type, dtype=dp.complex64)
def test_fftshift(func, usm_type):
dpnp_data = dp.fft.fftfreq(10, 0.5, usm_type=usm_type)
result = getattr(dp.fft, func)(dpnp_data)

assert dpnp_data.usm_type == usm_type
Expand Down
1 change: 0 additions & 1 deletion tests/third_party/cupy/fft_tests/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,6 @@ def test_rfftfreq(self, xp):
{"shape": (10, 10), "axes": 0},
{"shape": (10, 10), "axes": (0, 1)},
)
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
class TestFftshift:
@testing.for_all_dtypes()
@testing.numpy_cupy_allclose(
Expand Down
Loading