diff --git a/CHANGES.rst b/CHANGES.rst index 8a786484..e5aa718b 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -7,6 +7,9 @@ Version 0.0.6 (unreleased) ---------------------------- • Significant changes to ``linop.xray.astra`` API. +• Rename integrated 2D X-ray transform class to + ``linop.xray.XRayTransform2D`` and add filtered back projection method + ``fbp``. • New integrated 3D X-ray transform via ``linop.xray.XRayTransform3D``. • New functional ``functional.IsotropicTVNorm`` and faster implementation of ``functional.AnisotropicTVNorm``. @@ -17,8 +20,8 @@ Version 0.0.6 (unreleased) • Rename ``scico.flax.save_weights`` and ``scico.flax.load_weights`` to ``scico.flax.save_variables`` and ``scico.flax.load_variables`` respectively. -• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.31. -• Support ``flax`` versions 0.8.0 to 0.8.3. +• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.33. +• Support ``flax`` versions 0.8.0 to 0.9.0. diff --git a/docs/source/references.bib b/docs/source/references.bib index 257f2428..e612e36e 100644 --- a/docs/source/references.bib +++ b/docs/source/references.bib @@ -396,6 +396,13 @@ @Article {jin-2017-unet doi = {10.1109/TIP.2017.2713099} } +@Book {kak-1988-principles, + author = {Avinash C. Kak and Malcolm Slaney}, + title = {Principles of Computerized Tomographic Imaging}, + publisher = {IEEE Press}, + year = 1988 +} + @TechReport {kamilov-2016-minimizing, author = {Ulugbek S. Kamilov}, title = {Minimizing Isotropic Total Variation without @@ -771,6 +778,7 @@ @Article {zhang-2017-dncnn pages = {3142--3155} } + @Article {zhang-2021-plug, author = {Zhang, Kai and Li, Yawei and Zuo, Wangmeng and Zhang, Lei and Van Gool, Luc and Timofte, Radu}, diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index f3067e9a..83bd8462 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -68,10 +68,10 @@ def __init__( corresponds to summing along antidiagonals. x0: (x, y) position of the corner of the pixel `im[0,0]`. By default, `(-input_shape * dx[0] / 2, -input_shape * dx[1] / 2)`. - dx: Image pixel side length in x- and y-direction. Must be - set so that the width of a projected pixel is never - larger than 1.0. By default, [:math:`\sqrt{2}/2`, - :math:`\sqrt{2}/2`]. + dx: Image pixel side length in x- and y-direction (axis 0 and + 1 respectively). Must be set so that the width of a + projected pixel is never larger than 1.0. By default, + [:math:`\sqrt{2}/2`, :math:`\sqrt{2}/2`]. y0: Location of the edge of the first detector bin. By default, `-det_count / 2` det_count: Number of elements in detector. If ``None``, @@ -114,6 +114,9 @@ def __init__( self.y0 = y0 self.dy = 1.0 + self.fbp_filter: Optional[snp.Array] = None + self.fbp_mask: Optional[snp.Array] = None + super().__init__( input_shape=self.input_shape, input_dtype=np.float32, @@ -139,6 +142,71 @@ def back_project(self, y: ArrayLike) -> snp.Array: """ return XRayTransform2D._back_project(y, self.x0, self.dx, self.nx, self.y0, self.angles) + def fbp(self, y: ArrayLike) -> snp.Array: + r"""Compute filtered back projection (FBP) inverse of projection. + + Compute the filtered back projection inverse by filtering each + row of the sinogram with the filter defined in (61) in + :cite:`kak-1988-principles` and then back projecting. The + projection angles are assumed to be evenly spaced in + :math:`[0, \pi)`; reconstruction quality may be poor if + this assumption is violated. Poor quality reconstructions should + also be expected when `dx[0]` and `dx[1]` are not equal. + + Args: + y: Input projection, (num_angles, N). + + Returns: + FBP inverse of projection. + """ + N = y.shape[1] + + if self.fbp_filter is None: + nvec = jnp.arange(N) - (N - 1) // 2 + self.fbp_filter = XRayTransform2D._ramp_filter(nvec, 1.0).reshape(1, -1) + + if self.fbp_mask is None: + unit_sino = jnp.ones(self.output_shape, dtype=np.float32) + # Threshold is multiplied by 0.99... fudge factor to account for numerical errors + # in back projection. + self.fbp_mask = self.back_project(unit_sino) >= (self.output_shape[0] * (1.0 - 1e-5)) # type: ignore + + # Apply ramp filter in the frequency domain, padding to avoid + # boundary effects + h = self.fbp_filter + hf = jnp.fft.fft(h, n=2 * N - 1, axis=1) + yf = jnp.fft.fft(y, n=2 * N - 1, axis=1) + hy = jnp.fft.ifft(hf * yf, n=2 * N - 1, axis=1)[ + :, (N - 1) // 2 : -(N - 1) // 2 + ].real.astype(jnp.float32) + + x = (jnp.pi * self.dx[0] * self.dx[1] / y.shape[0]) * self.fbp_mask * self.back_project(hy) # type: ignore + return x + + @staticmethod + def _ramp_filter(x: ArrayLike, tau: float) -> snp.Array: + """Compute coefficients of ramp filter used in FBP. + + Compute coefficients of ramp filter used in FBP, as defined in + (61) in :cite:`kak-1988-principles`. + + Args: + x: Sampling locations at which to compute filter coefficients. + tau: Sampling rate. + + Returns: + Spatial-domain coefficients of ramp filter. + """ + # The (x == 0) term in x**2 * np.pi**2 * tau**2 + (x == 0) + # is included to avoid division by zero warnings when x == 1 + # since np.where evaluates all values for both True and False + # branches. + return jnp.where( + x == 0, + 1.0 / (4.0 * tau**2), + jnp.where(x % 2, -1.0 / (x**2 * np.pi**2 * tau**2 + (x == 0)), 0), + ) + @staticmethod @partial(jax.jit, static_argnames=["ny"]) def _project( @@ -179,7 +247,7 @@ def _project( @partial(jax.jit, static_argnames=["nx"]) def _back_project( y: ArrayLike, x0: ArrayLike, dx: ArrayLike, nx: Shape, y0: float, angles: ArrayLike - ) -> ArrayLike: + ) -> snp.Array: r"""Compute X-ray back projection. Args: @@ -361,7 +429,7 @@ def _project_single( return proj @staticmethod - def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> ArrayLike: + def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> snp.Array: r""" Args: proj: Input (set of) projection(s). diff --git a/scico/test/linop/xray/test_xray_2d.py b/scico/test/linop/xray/test_xray_2d.py index deab4d32..3a9d7488 100644 --- a/scico/test/linop/xray/test_xray_2d.py +++ b/scico/test/linop/xray/test_xray_2d.py @@ -1,5 +1,6 @@ import numpy as np +import jax import jax.numpy as jnp import pytest @@ -7,6 +8,7 @@ import scico import scico.linop from scico.linop.xray import XRayTransform2D +from scico.metric import psnr @pytest.mark.filterwarnings("error") @@ -81,3 +83,33 @@ def test_matched_adjoint(): angles = np.linspace(0, np.pi, n_projection, endpoint=False) A = XRayTransform2D((N, N), angles, det_count=det_count, dx=dx) assert scico.linop.valid_adjoint(A, A.T, eps=1e-5) + + +@pytest.mark.parametrize("dx", [0.5, 1.0 / np.sqrt(2)]) +@pytest.mark.parametrize("det_count_factor", [1.02 / np.sqrt(2.0), 1.0]) +def test_fbp(dx, det_count_factor): + N = 256 + x_gt = np.zeros((N, N), dtype=np.float32) + N4 = N // 4 + x_gt[N4:-N4, N4:-N4] = 1.0 + + det_count = int(det_count_factor * N) + n_proj = 360 + angles = np.linspace(0, np.pi, n_proj, endpoint=False) + A = XRayTransform2D(x_gt.shape, angles, det_count=det_count, dx=dx) + y = A(x_gt) + x_fbp = A.fbp(y) + assert psnr(x_gt, x_fbp) > 28 + + +def test_fbp_jit(): + N = 64 + x_gt = np.ones((N, N), dtype=np.float32) + + det_count = N + n_proj = 90 + angles = np.linspace(0, np.pi, n_proj, endpoint=False) + A = XRayTransform2D(x_gt.shape, angles, det_count=det_count) + y = A(x_gt) + fbp = jax.jit(A.fbp) + x_fbp = fbp(y)