From 61f1c936bf726e9e0612c5ef13cd5f7ed9ea589b Mon Sep 17 00:00:00 2001 From: Mike McCann <57153404+Michael-T-McCann@users.noreply.github.com> Date: Wed, 15 Nov 2023 15:52:30 -0800 Subject: [PATCH] Fix fractional centers in CircularConvolve (#471) * Add failing test * Fix phase ramp * Minor comment improvement * Improve RNG key use --------- Co-authored-by: Brendt Wohlberg --- scico/linop/_circconv.py | 11 +++++++++-- scico/test/linop/test_circconv.py | 17 +++++++++++++++-- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/scico/linop/_circconv.py b/scico/linop/_circconv.py index dfbd0bb86..deacce792 100644 --- a/scico/linop/_circconv.py +++ b/scico/linop/_circconv.py @@ -120,7 +120,6 @@ def __init__( output_dtype = snp.dtype(input_dtype) # cannot infer from h_dft because it is complex else: fft_shape = input_shape[-self.ndims :] - pad = () fft_axes = list(range(h.ndim - self.ndims, h.ndim)) self.h_dft = snp.fft.fftn(h, s=fft_shape, axes=fft_axes) output_dtype = result_type(h.dtype, input_dtype) @@ -140,7 +139,15 @@ def __init__( offset = -snp.array(self.h_center) shifts: Tuple[np.ndarray, ...] = np.ix_( *tuple( - np.exp(-1j * k * 2 * np.pi * np.fft.fftfreq(s)) # type: ignore + np.select( + # see doi:10.1109/78.700979 and doi:10.1109/LSP.2012.2191280 + [np.arange(s) < s / 2, np.arange(s) == s / 2, np.arange(s) > s / 2], + [ + np.exp(-1j * k * 2 * np.pi * np.arange(s) / s), + np.cos(k * np.pi), + np.exp(1j * k * 2 * np.pi * (s - np.arange(s)) / s), + ], # type: ignore + ) for k, s in zip(offset, input_shape[-self.ndims :]) ) ) diff --git a/scico/test/linop/test_circconv.py b/scico/test/linop/test_circconv.py index 9e6cb8c1c..8eb2859fc 100644 --- a/scico/test/linop/test_circconv.py +++ b/scico/test/linop/test_circconv.py @@ -33,7 +33,6 @@ def setup_method(self, method): @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("axes_shape_spec", SHAPE_SPECS) def test_eval(self, axes_shape_spec, input_dtype, jit): - x_shape, ndims, h_shape = axes_shape_spec h, key = randn(tuple(h_shape), dtype=input_dtype, key=self.key) @@ -62,7 +61,6 @@ def test_eval(self, axes_shape_spec, input_dtype, jit): @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("axes_shape_spec", SHAPE_SPECS) def test_adjoint(self, axes_shape_spec, input_dtype, jit): - x_shape, ndims, h_shape = axes_shape_spec h, key = randn(tuple(h_shape), dtype=input_dtype, key=self.key) @@ -160,6 +158,21 @@ def test_center(self, center): shift = -center[0] np.testing.assert_allclose(A @ x, snp.roll(B @ x, shift), atol=1e-5) + def test_fractional_center(self): + """A fractional center should keep outputs real.""" + x, key = uniform(minval=-1, maxval=1, shape=(4, 5), key=self.key) + h, _ = uniform(minval=-1, maxval=1, shape=(2, 2), key=key) + A = CircularConvolve(h=h, input_shape=x.shape, h_center=[0.1, 2.7]) + + # taken from CircularConvolve._eval + x_dft = snp.fft.fftn(x, axes=A.x_fft_axes) + hx = snp.fft.ifftn( + A.h_dft * x_dft, + axes=A.ifft_axes, + ) + + np.testing.assert_allclose(hx, snp.real(hx)) + @pytest.mark.parametrize("axes_shape_spec", SHAPE_SPECS) @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("jit_old_op", [True, False])