Skip to content

Commit

Permalink
1. fix ifftshift(missing negative sign before shifts); (PaddlePaddle#…
Browse files Browse the repository at this point in the history
…36834)

2. add complex data type support for paddle.shape at graph assembly.
  • Loading branch information
Feiyu Chan authored and piotrekobi committed Nov 3, 2021
1 parent 3ef3081 commit e0b4aae
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 14 deletions.
2 changes: 1 addition & 1 deletion python/paddle/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -1345,7 +1345,7 @@ def ifftshift(x, axes=None, name=None):
# shift all axes
rank = len(x.shape)
axes = list(range(0, rank))
shifts = shape // 2
shifts = -shape // 2
elif isinstance(axes, int):
shifts = -shape[axes] // 2
else:
Expand Down
7 changes: 4 additions & 3 deletions python/paddle/fluid/layers/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11396,9 +11396,10 @@ def shape(input):
res = exe.run(fluid.default_main_program(), feed={'x':img}, fetch_list=[output])
print(res) # [array([ 3, 100, 100], dtype=int32)]
"""
check_variable_and_dtype(
input, 'input',
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], 'shape')
check_variable_and_dtype(input, 'input', [
'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'complex64',
'complex128'
], 'shape')
helper = LayerHelper('shape', **locals())
out = helper.create_variable_for_type_inference(dtype='int32')
helper.append_op(
Expand Down
24 changes: 14 additions & 10 deletions python/paddle/fluid/tests/unittests/fft/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,11 +1009,13 @@ def test_rfftfreq(self):


@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'axes', 'dtype'),
[('test_1d', np.random.randn(10), (0, ), 'float64'),
('test_2d', np.random.randn(10, 10), (0, 1), 'float64'),
('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64')])
@parameterize((TEST_CASE_NAME, 'x', 'axes', 'dtype'), [
('test_1d', np.random.randn(10), (0, ), 'float64'),
('test_2d', np.random.randn(10, 10), (0, 1), 'float64'),
('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64'),
('test_2d_odd_with_all_axes',
np.random.randn(5, 5) + 1j * np.random.randn(5, 5), None, 'complex128'),
])
class TestFftShift(unittest.TestCase):
def test_fftshift(self):
"""Test fftshift with norm condition
Expand All @@ -1028,11 +1030,13 @@ def test_fftshift(self):


@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'axes'), [
('test_1d', np.random.randn(10), (0, ), 'float64'),
('test_2d', np.random.randn(10, 10), (0, 1), 'float64'),
('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64'),
])
@parameterize(
(TEST_CASE_NAME, 'x', 'axes'),
[('test_1d', np.random.randn(10), (0, ),
'float64'), ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'),
('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64'),
('test_2d_odd_with_all_axes',
np.random.randn(5, 5) + 1j * np.random.randn(5, 5), None, 'complex128')])
class TestIfftShift(unittest.TestCase):
def test_ifftshift(self):
"""Test ifftshift with norm condition
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,56 @@ def test_static_ihfftn(self):
pass


@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'x', 'axes', 'dtype'), [
('test_1d', np.random.randn(10), (0, ), 'float64'),
('test_2d', np.random.randn(10, 10), (0, 1), 'float64'),
('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64'),
('test_2d_odd_with_all_axes',
np.random.randn(5, 5) + 1j * np.random.randn(5, 5), None, 'complex128'),
])
class TestFftShift(unittest.TestCase):
def test_fftshift(self):
"""Test fftshift with norm condition
"""
paddle.enable_static()
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
input = paddle.static.data('input', x.shape, dtype=x.dtype)
output = paddle.fft.fftshift(input, axes)

exe = paddle.static.Executor(place)
exe.run(sp)
[output] = exe.run(mp, feed={'input': x}, fetch_list=[output])
yield output
paddle.disable_static()


@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'x', 'axes'),
[('test_1d', np.random.randn(10), (0, ),
'float64'), ('test_2d', np.random.randn(10, 10), (0, 1), 'float64'),
('test_2d_with_all_axes', np.random.randn(10, 10), None, 'float64'),
('test_2d_odd_with_all_axes',
np.random.randn(5, 5) + 1j * np.random.randn(5, 5), None, 'complex128')])
class TestIfftShift(unittest.TestCase):
def test_ifftshift(self):
"""Test ifftshift with norm condition
"""
paddle.enable_static()
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
input = paddle.static.data('input', x.shape, dtype=x.dtype)
output = paddle.fft.ifftshift(input, axes)

exe = paddle.static.Executor(place)
exe.run(sp)
[output] = exe.run(mp, feed={'input': x}, fetch_list=[output])
yield output
paddle.disable_static()


if __name__ == '__main__':
unittest.main()

Expand Down

0 comments on commit e0b4aae

Please sign in to comment.