From 1276abf07d16200868021012fa1ed90db9daf5bc Mon Sep 17 00:00:00 2001 From: KP <109694228@qq.com> Date: Fri, 18 Mar 2022 08:23:42 +0000 Subject: [PATCH] Update unitests of stft op. --- paddle/fluid/operators/overlap_add_op.cc | 2 +- python/paddle/fluid/tests/unittests/test_stft_op.py | 2 +- python/paddle/signal.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/overlap_add_op.cc b/paddle/fluid/operators/overlap_add_op.cc index 5935dc56d0ae3..0e6f0f8422106 100644 --- a/paddle/fluid/operators/overlap_add_op.cc +++ b/paddle/fluid/operators/overlap_add_op.cc @@ -81,7 +81,7 @@ class OverlapAddOp : public framework::OperatorWithKernel { hop_length, frame_length)); } - if (n_frames == -1 && frame_length == -1) { + if (n_frames == -1) { seq_length = -1; } else { seq_length = (n_frames - 1) * hop_length + frame_length; diff --git a/python/paddle/fluid/tests/unittests/test_stft_op.py b/python/paddle/fluid/tests/unittests/test_stft_op.py index 486c6c79a9250..64b8084a1651f 100644 --- a/python/paddle/fluid/tests/unittests/test_stft_op.py +++ b/python/paddle/fluid/tests/unittests/test_stft_op.py @@ -59,7 +59,7 @@ def setUp(self): self.outputs = {'Out': stft_np(x=self.inputs['X'], **self.attrs)} def initTestCase(self): - input_shape = (2, 1600) + input_shape = (2, 100) input_type = 'float64' attrs = { 'n_fft': 50, diff --git a/python/paddle/signal.py b/python/paddle/signal.py index df0781ed74e80..f5b225bc6da2d 100644 --- a/python/paddle/signal.py +++ b/python/paddle/signal.py @@ -546,7 +546,7 @@ def istft(x, window_envelop = overlap_add( x=paddle.tile( - x=paddle.multiply(window, window), + x=paddle.multiply(window, window).unsqueeze(0), repeat_times=[n_frames, 1]).transpose( perm=[1, 0]), # (n_fft, num_frames) hop_length=hop_length, @@ -566,7 +566,7 @@ def istft(x, window_envelop = window_envelop[start:start + length] # Check whether the Nonzero Overlap Add (NOLA) constraint is met. - if window_envelop.abs().min().item() < 1e-11: + if in_dygraph_mode() and window_envelop.abs().min().item() < 1e-11: raise ValueError( 'Abort istft because Nonzero Overlap Add (NOLA) condition failed. For more information about NOLA constraint please see `scipy.signal.check_NOLA`(https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.check_NOLA.html).' )