From 4d6c632de5bf6eb37e4921acd1f67089fcad1237 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Sat, 5 Oct 2024 13:44:13 -0600 Subject: [PATCH 1/9] Add filtered back projection for 2D projector --- docs/source/references.bib | 8 ++++ scico/linop/xray/_xray.py | 63 ++++++++++++++++++++++++++++++ scico/test/linop/xray/test_xray.py | 17 ++++++++ 3 files changed, 88 insertions(+) diff --git a/docs/source/references.bib b/docs/source/references.bib index 257f24287..e612e36e1 100644 --- a/docs/source/references.bib +++ b/docs/source/references.bib @@ -396,6 +396,13 @@ @Article {jin-2017-unet doi = {10.1109/TIP.2017.2713099} } +@Book {kak-1988-principles, + author = {Avinash C. Kak and Malcolm Slaney}, + title = {Principles of Computerized Tomographic Imaging}, + publisher = {IEEE Press}, + year = 1988 +} + @TechReport {kamilov-2016-minimizing, author = {Ulugbek S. Kamilov}, title = {Minimizing Isotropic Total Variation without @@ -771,6 +778,7 @@ @Article {zhang-2017-dncnn pages = {3142--3155} } + @Article {zhang-2021-plug, author = {Zhang, Kai and Li, Yawei and Zuo, Wangmeng and Zhang, Lei and Van Gool, Luc and Timofte, Radu}, diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index 770bf627d..127b4e7ec 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -124,6 +124,69 @@ def back_project(self, y: ArrayLike) -> snp.Array: """Compute X-ray back projection""" return XRayTransform2D._back_project(y, self.x0, self.dx, self.nx, self.y0, self.angles) + def fbp(self, y: ArrayLike) -> snp.Array: + """Compute Filter Back Projection inverse of projection. + + Compute the Filter Back Projection inverse by filtering each row + of the sinogram with the filter defined in (61) in + :cite:`kak-1988-principles` and then back projecting. The + projection angles are assumed to be evenly spaced: poor results + may be obtained if this assumption is violated. + + Args: + y: Input projection, (num_angles, N). + + Returns: + Filtered Back Projection inverse of projection. + """ + + N = y.shape[1] + nvec = np.arange(N) - (N - 1) // 2 + dx = np.sqrt(self.dx[0] * self.dx[1]) # type: ignore + h = XRayTransform2D._ramp_filter(nvec, 1.0 / dx) + + # Apply ramp filter in the frequency domain, padding to avoid + # boundary effects + hf = snp.fft.fft(h.reshape(1, -1), n=2 * N - 1, axis=1) + yf = snp.fft.fft(y, n=2 * N - 1, axis=1) + hy = snp.fft.ifft(hf * yf, n=2 * N - 1, axis=1)[ + :, (N - 1) // 2 : -(N - 1) // 2 + ].real.astype(snp.float32) + + x = (snp.pi / y.shape[0]) * self.back_project(hy) + # Mask out the invalid region of the reconstruction + gi, gj = snp.mgrid[: x.shape[0], : x.shape[1]] + x = snp.where( + snp.sqrt((gi - x.shape[0] / 2) ** 2 + (gj - x.shape[1] / 2) ** 2) < min(x.shape) / 2, + x, + 0.0, + ) + return x + + @staticmethod + def _ramp_filter(x: ArrayLike, tau: float) -> snp.Array: + """Compute coefficients of ramp filter used in FBP. + + Compute coefficients of ramp filter used in FBP, as defined in + (61) in :cite:`kak-1988-principles`. + + Args: + x: Sampling locations at which to compute filter coefficients. + tau: Sampling rate. + + Returns: + Spatial-domain coefficients of ramp filter. + """ + # The (x == 0) term in x**2 * np.pi**2 * tau**2 + (x == 0) + # is included to avoid division by zero warnings when x == 1 + # since np.where evaluates all values for both True and False + # branches. + return snp.where( + x == 0, + 1.0 / (4.0 * tau**2), + snp.where(x % 2, -1.0 / (x**2 * np.pi**2 * tau**2 + (x == 0)), 0), + ) + @staticmethod @partial(jax.jit, static_argnames=["ny"]) def _project( diff --git a/scico/test/linop/xray/test_xray.py b/scico/test/linop/xray/test_xray.py index cd7c0dcdd..4aab2e928 100644 --- a/scico/test/linop/xray/test_xray.py +++ b/scico/test/linop/xray/test_xray.py @@ -6,6 +6,7 @@ import scico from scico.linop.xray import XRayTransform2D, XRayTransform3D +from scico.metric import psnr @pytest.mark.filterwarnings("error") @@ -71,6 +72,22 @@ def test_apply_adjoint(): assert y.shape[1] == det_count +@pytest.mark.parametrize("dx", [0.5, 1.0 / np.sqrt(2)]) +@pytest.mark.parametrize("det_count_factor", [1.02 / np.sqrt(2.0), 1.0]) +def test_fbp(dx, det_count_factor): + N = 256 + x_gt = np.zeros((256, 256), dtype=np.float32) + x_gt[64:-64, 64:-64] = 1.0 + + det_count = int(det_count_factor * N) + n_proj = 360 + angles = np.linspace(0, np.pi, n_proj) + A = XRayTransform2D(x_gt.shape, angles, det_count=det_count, dx=dx) + y = A(x_gt) + x_fbp = A.fbp(y) + assert psnr(x_gt, x_fbp) > 28 + + def test_3d_scaling(): x = jnp.zeros((4, 4, 1)) x = x.at[1:3, 1:3, 0].set(1.0) From 1497ee69eddfb271c08f68ac1c51f65103936cee Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Sat, 5 Oct 2024 14:08:29 -0600 Subject: [PATCH 2/9] Update change summary --- CHANGES.rst | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 8a7864843..e5aa718b1 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -7,6 +7,9 @@ Version 0.0.6 (unreleased) ---------------------------- • Significant changes to ``linop.xray.astra`` API. +• Rename integrated 2D X-ray transform class to + ``linop.xray.XRayTransform2D`` and add filtered back projection method + ``fbp``. • New integrated 3D X-ray transform via ``linop.xray.XRayTransform3D``. • New functional ``functional.IsotropicTVNorm`` and faster implementation of ``functional.AnisotropicTVNorm``. @@ -17,8 +20,8 @@ Version 0.0.6 (unreleased) • Rename ``scico.flax.save_weights`` and ``scico.flax.load_weights`` to ``scico.flax.save_variables`` and ``scico.flax.load_variables`` respectively. -• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.31. -• Support ``flax`` versions 0.8.0 to 0.8.3. +• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.33. +• Support ``flax`` versions 0.8.0 to 0.9.0. From 97f3b05e8dbc11625f12363f9d556a6d4e2626b1 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Sat, 5 Oct 2024 14:10:20 -0600 Subject: [PATCH 3/9] Docstring fixes --- scico/linop/xray/_xray.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index 127b4e7ec..02eb2f2f5 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -125,10 +125,10 @@ def back_project(self, y: ArrayLike) -> snp.Array: return XRayTransform2D._back_project(y, self.x0, self.dx, self.nx, self.y0, self.angles) def fbp(self, y: ArrayLike) -> snp.Array: - """Compute Filter Back Projection inverse of projection. + """Compute filtered back projection (FBP) inverse of projection. - Compute the Filter Back Projection inverse by filtering each row - of the sinogram with the filter defined in (61) in + Compute the filtered back projection inverse by filtering each + row of the sinogram with the filter defined in (61) in :cite:`kak-1988-principles` and then back projecting. The projection angles are assumed to be evenly spaced: poor results may be obtained if this assumption is violated. @@ -137,7 +137,7 @@ def fbp(self, y: ArrayLike) -> snp.Array: y: Input projection, (num_angles, N). Returns: - Filtered Back Projection inverse of projection. + FBP inverse of projection. """ N = y.shape[1] From ff1e2354661bd4aae25f37fe8307216d8c0c7eed Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Sun, 6 Oct 2024 11:16:02 -0600 Subject: [PATCH 4/9] Resolve errors in jitting method --- scico/linop/xray/_xray.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index 02eb2f2f5..bd1cc0ea6 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -141,8 +141,8 @@ def fbp(self, y: ArrayLike) -> snp.Array: """ N = y.shape[1] - nvec = np.arange(N) - (N - 1) // 2 - dx = np.sqrt(self.dx[0] * self.dx[1]) # type: ignore + nvec = snp.arange(N) - (N - 1) // 2 + dx = snp.sqrt(self.dx[0] * self.dx[1]) # type: ignore h = XRayTransform2D._ramp_filter(nvec, 1.0 / dx) # Apply ramp filter in the frequency domain, padding to avoid From 5d37894cd15e1d649e3d78a164728096b86f098e Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 9 Oct 2024 07:59:26 -0600 Subject: [PATCH 5/9] Some improvements --- scico/linop/xray/_xray.py | 7 ++++--- scico/test/linop/xray/test_xray.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index bd1cc0ea6..ff17e58c0 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -125,13 +125,14 @@ def back_project(self, y: ArrayLike) -> snp.Array: return XRayTransform2D._back_project(y, self.x0, self.dx, self.nx, self.y0, self.angles) def fbp(self, y: ArrayLike) -> snp.Array: - """Compute filtered back projection (FBP) inverse of projection. + r"""Compute filtered back projection (FBP) inverse of projection. Compute the filtered back projection inverse by filtering each row of the sinogram with the filter defined in (61) in :cite:`kak-1988-principles` and then back projecting. The - projection angles are assumed to be evenly spaced: poor results - may be obtained if this assumption is violated. + projection angles are assumed to be evenly spaced in + :math:`[0, \pi)`; reconstruction quality may be poor if + this assumption is violated. Args: y: Input projection, (num_angles, N). diff --git a/scico/test/linop/xray/test_xray.py b/scico/test/linop/xray/test_xray.py index 4aab2e928..44e8c05b6 100644 --- a/scico/test/linop/xray/test_xray.py +++ b/scico/test/linop/xray/test_xray.py @@ -81,7 +81,7 @@ def test_fbp(dx, det_count_factor): det_count = int(det_count_factor * N) n_proj = 360 - angles = np.linspace(0, np.pi, n_proj) + angles = np.linspace(0, np.pi, n_proj, endpoint=False) A = XRayTransform2D(x_gt.shape, angles, det_count=det_count, dx=dx) y = A(x_gt) x_fbp = A.fbp(y) From f49430d4c373a19349449b9f0039773ea9d5e3de Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 14 Oct 2024 14:59:58 -0600 Subject: [PATCH 6/9] Clean up --- scico/linop/xray/_xray.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index ff17e58c0..0656a43bc 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -142,23 +142,22 @@ def fbp(self, y: ArrayLike) -> snp.Array: """ N = y.shape[1] - nvec = snp.arange(N) - (N - 1) // 2 - dx = snp.sqrt(self.dx[0] * self.dx[1]) # type: ignore - h = XRayTransform2D._ramp_filter(nvec, 1.0 / dx) + nvec = jnp.arange(N) - (N - 1) // 2 + h = XRayTransform2D._ramp_filter(nvec, 1.0).reshape(1, -1) # Apply ramp filter in the frequency domain, padding to avoid # boundary effects - hf = snp.fft.fft(h.reshape(1, -1), n=2 * N - 1, axis=1) - yf = snp.fft.fft(y, n=2 * N - 1, axis=1) - hy = snp.fft.ifft(hf * yf, n=2 * N - 1, axis=1)[ + hf = jnp.fft.fft(h, n=2 * N - 1, axis=1) + yf = jnp.fft.fft(y, n=2 * N - 1, axis=1) + hy = jnp.fft.ifft(hf * yf, n=2 * N - 1, axis=1)[ :, (N - 1) // 2 : -(N - 1) // 2 - ].real.astype(snp.float32) + ].real.astype(jnp.float32) - x = (snp.pi / y.shape[0]) * self.back_project(hy) + x = (jnp.pi * self.dx[0] * self.dx[1] / y.shape[0]) * self.back_project(hy) # Mask out the invalid region of the reconstruction - gi, gj = snp.mgrid[: x.shape[0], : x.shape[1]] - x = snp.where( - snp.sqrt((gi - x.shape[0] / 2) ** 2 + (gj - x.shape[1] / 2) ** 2) < min(x.shape) / 2, + gi, gj = jnp.mgrid[: x.shape[0], : x.shape[1]] + x = jnp.where( + jnp.sqrt((gi - x.shape[0] / 2) ** 2 + (gj - x.shape[1] / 2) ** 2) < min(x.shape) / 2, x, 0.0, ) @@ -182,10 +181,10 @@ def _ramp_filter(x: ArrayLike, tau: float) -> snp.Array: # is included to avoid division by zero warnings when x == 1 # since np.where evaluates all values for both True and False # branches. - return snp.where( + return jnp.where( x == 0, 1.0 / (4.0 * tau**2), - snp.where(x % 2, -1.0 / (x**2 * np.pi**2 * tau**2 + (x == 0)), 0), + jnp.where(x % 2, -1.0 / (x**2 * np.pi**2 * tau**2 + (x == 0)), 0), ) @staticmethod From 5c9c974324b5bb94cc924ad015fc3bb74a22aa18 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 14 Oct 2024 18:56:58 -0600 Subject: [PATCH 7/9] Improve tests --- scico/test/linop/xray/test_xray.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/scico/test/linop/xray/test_xray.py b/scico/test/linop/xray/test_xray.py index 44e8c05b6..9efddffab 100644 --- a/scico/test/linop/xray/test_xray.py +++ b/scico/test/linop/xray/test_xray.py @@ -1,5 +1,6 @@ import numpy as np +import jax import jax.numpy as jnp import pytest @@ -76,8 +77,9 @@ def test_apply_adjoint(): @pytest.mark.parametrize("det_count_factor", [1.02 / np.sqrt(2.0), 1.0]) def test_fbp(dx, det_count_factor): N = 256 - x_gt = np.zeros((256, 256), dtype=np.float32) - x_gt[64:-64, 64:-64] = 1.0 + x_gt = np.zeros((N, N), dtype=np.float32) + N4 = N // 4 + x_gt[N4:-N4, N4:-N4] = 1.0 det_count = int(det_count_factor * N) n_proj = 360 @@ -88,6 +90,19 @@ def test_fbp(dx, det_count_factor): assert psnr(x_gt, x_fbp) > 28 +def test_fbp_jit(): + N = 64 + x_gt = np.ones((N, N), dtype=np.float32) + + det_count = N + n_proj = 90 + angles = np.linspace(0, np.pi, n_proj, endpoint=False) + A = XRayTransform2D(x_gt.shape, angles, det_count=det_count) + y = A(x_gt) + fbp = jax.jit(A.fbp) + x_fbp = fbp(y) + + def test_3d_scaling(): x = jnp.zeros((4, 4, 1)) x = x.at[1:3, 1:3, 0].set(1.0) From b6bded809e58a72491d7b047a4a11edf22711519 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 15 Oct 2024 15:07:23 -0600 Subject: [PATCH 8/9] Improve mask mechanism --- scico/linop/xray/_xray.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index 2884dd00c..bd68ad61a 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -114,6 +114,9 @@ def __init__( self.y0 = y0 self.dy = 1.0 + self.fbp_filter: Optional[snp.Array] = None + self.fbp_mask: Optional[snp.Array] = None + super().__init__( input_shape=self.input_shape, input_dtype=np.float32, @@ -155,27 +158,28 @@ def fbp(self, y: ArrayLike) -> snp.Array: Returns: FBP inverse of projection. """ - N = y.shape[1] - nvec = jnp.arange(N) - (N - 1) // 2 - h = XRayTransform2D._ramp_filter(nvec, 1.0).reshape(1, -1) + + if self.fbp_filter is None: + nvec = jnp.arange(N) - (N - 1) // 2 + self.fbp_filter = XRayTransform2D._ramp_filter(nvec, 1.0).reshape(1, -1) + + if self.fbp_mask is None: + unit_sino = jnp.ones(self.output_shape, dtype=np.float32) + # Threshold is multiplied by 0.99... fudge factor to account for numerical errors + # in back projection. + self.fbp_mask = self.back_project(unit_sino) >= (self.output_shape[0] * (1.0 - 1e-5)) # type: ignore # Apply ramp filter in the frequency domain, padding to avoid # boundary effects + h = self.fbp_filter hf = jnp.fft.fft(h, n=2 * N - 1, axis=1) yf = jnp.fft.fft(y, n=2 * N - 1, axis=1) hy = jnp.fft.ifft(hf * yf, n=2 * N - 1, axis=1)[ :, (N - 1) // 2 : -(N - 1) // 2 ].real.astype(jnp.float32) - x = (jnp.pi * self.dx[0] * self.dx[1] / y.shape[0]) * self.back_project(hy) - # Mask out the invalid region of the reconstruction - gi, gj = jnp.mgrid[: x.shape[0], : x.shape[1]] - x = jnp.where( - jnp.sqrt((gi - x.shape[0] / 2) ** 2 + (gj - x.shape[1] / 2) ** 2) < min(x.shape) / 2, - x, - 0.0, - ) + x = (jnp.pi * self.dx[0] * self.dx[1] / y.shape[0]) * self.fbp_mask * self.back_project(hy) # type: ignore return x @staticmethod @@ -242,7 +246,7 @@ def _project( @partial(jax.jit, static_argnames=["nx"]) def _back_project( y: ArrayLike, x0: ArrayLike, dx: ArrayLike, nx: Shape, y0: float, angles: ArrayLike - ) -> ArrayLike: + ) -> snp.Array: r"""Compute X-ray back projection. Args: @@ -424,7 +428,7 @@ def _project_single( return proj @staticmethod - def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> ArrayLike: + def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> snp.Array: r""" Args: proj: Input (set of) projection(s). From 62f0a50ded78d3f03e2f43d8abf635c1b458f429 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 15 Oct 2024 15:20:50 -0600 Subject: [PATCH 9/9] Improve docs --- scico/linop/xray/_xray.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index bd68ad61a..83bd84624 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -68,10 +68,10 @@ def __init__( corresponds to summing along antidiagonals. x0: (x, y) position of the corner of the pixel `im[0,0]`. By default, `(-input_shape * dx[0] / 2, -input_shape * dx[1] / 2)`. - dx: Image pixel side length in x- and y-direction. Must be - set so that the width of a projected pixel is never - larger than 1.0. By default, [:math:`\sqrt{2}/2`, - :math:`\sqrt{2}/2`]. + dx: Image pixel side length in x- and y-direction (axis 0 and + 1 respectively). Must be set so that the width of a + projected pixel is never larger than 1.0. By default, + [:math:`\sqrt{2}/2`, :math:`\sqrt{2}/2`]. y0: Location of the edge of the first detector bin. By default, `-det_count / 2` det_count: Number of elements in detector. If ``None``, @@ -150,7 +150,8 @@ def fbp(self, y: ArrayLike) -> snp.Array: :cite:`kak-1988-principles` and then back projecting. The projection angles are assumed to be evenly spaced in :math:`[0, \pi)`; reconstruction quality may be poor if - this assumption is violated. + this assumption is violated. Poor quality reconstructions should + also be expected when `dx[0]` and `dx[1]` are not equal. Args: y: Input projection, (num_angles, N).