Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add filtered back projection for 2D projector #558

Merged
merged 11 commits into from
Oct 18, 2024
Merged
7 changes: 5 additions & 2 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ Version 0.0.6 (unreleased)
----------------------------

• Significant changes to ``linop.xray.astra`` API.
• Rename integrated 2D X-ray transform class to
``linop.xray.XRayTransform2D`` and add filtered back projection method
``fbp``.
• New integrated 3D X-ray transform via ``linop.xray.XRayTransform3D``.
• New functional ``functional.IsotropicTVNorm`` and faster implementation
of ``functional.AnisotropicTVNorm``.
Expand All @@ -17,8 +20,8 @@ Version 0.0.6 (unreleased)
• Rename ``scico.flax.save_weights`` and ``scico.flax.load_weights`` to
``scico.flax.save_variables`` and ``scico.flax.load_variables``
respectively.
• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.31.
• Support ``flax`` versions 0.8.0 to 0.8.3.
• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.33.
• Support ``flax`` versions 0.8.0 to 0.9.0.



Expand Down
8 changes: 8 additions & 0 deletions docs/source/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,13 @@ @Article {jin-2017-unet
doi = {10.1109/TIP.2017.2713099}
}

@Book {kak-1988-principles,
author = {Avinash C. Kak and Malcolm Slaney},
title = {Principles of Computerized Tomographic Imaging},
publisher = {IEEE Press},
year = 1988
}

@TechReport {kamilov-2016-minimizing,
author = {Ulugbek S. Kamilov},
title = {Minimizing Isotropic Total Variation without
Expand Down Expand Up @@ -771,6 +778,7 @@ @Article {zhang-2017-dncnn
pages = {3142--3155}
}


@Article {zhang-2021-plug,
author = {Zhang, Kai and Li, Yawei and Zuo, Wangmeng and
Zhang, Lei and Van Gool, Luc and Timofte, Radu},
Expand Down
80 changes: 74 additions & 6 deletions scico/linop/xray/_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ def __init__(
corresponds to summing along antidiagonals.
x0: (x, y) position of the corner of the pixel `im[0,0]`. By
default, `(-input_shape * dx[0] / 2, -input_shape * dx[1] / 2)`.
dx: Image pixel side length in x- and y-direction. Must be
set so that the width of a projected pixel is never
larger than 1.0. By default, [:math:`\sqrt{2}/2`,
:math:`\sqrt{2}/2`].
dx: Image pixel side length in x- and y-direction (axis 0 and
1 respectively). Must be set so that the width of a
projected pixel is never larger than 1.0. By default,
[:math:`\sqrt{2}/2`, :math:`\sqrt{2}/2`].
y0: Location of the edge of the first detector bin. By
default, `-det_count / 2`
det_count: Number of elements in detector. If ``None``,
Expand Down Expand Up @@ -114,6 +114,9 @@ def __init__(
self.y0 = y0
self.dy = 1.0

self.fbp_filter: Optional[snp.Array] = None
self.fbp_mask: Optional[snp.Array] = None

super().__init__(
input_shape=self.input_shape,
input_dtype=np.float32,
Expand All @@ -139,6 +142,71 @@ def back_project(self, y: ArrayLike) -> snp.Array:
"""
return XRayTransform2D._back_project(y, self.x0, self.dx, self.nx, self.y0, self.angles)

def fbp(self, y: ArrayLike) -> snp.Array:
r"""Compute filtered back projection (FBP) inverse of projection.

Compute the filtered back projection inverse by filtering each
row of the sinogram with the filter defined in (61) in
:cite:`kak-1988-principles` and then back projecting. The
projection angles are assumed to be evenly spaced in
:math:`[0, \pi)`; reconstruction quality may be poor if
this assumption is violated. Poor quality reconstructions should
also be expected when `dx[0]` and `dx[1]` are not equal.

Args:
y: Input projection, (num_angles, N).

Returns:
FBP inverse of projection.
"""
N = y.shape[1]

if self.fbp_filter is None:
nvec = jnp.arange(N) - (N - 1) // 2
self.fbp_filter = XRayTransform2D._ramp_filter(nvec, 1.0).reshape(1, -1)

if self.fbp_mask is None:
unit_sino = jnp.ones(self.output_shape, dtype=np.float32)
# Threshold is multiplied by 0.99... fudge factor to account for numerical errors
# in back projection.
self.fbp_mask = self.back_project(unit_sino) >= (self.output_shape[0] * (1.0 - 1e-5)) # type: ignore

# Apply ramp filter in the frequency domain, padding to avoid
# boundary effects
h = self.fbp_filter
hf = jnp.fft.fft(h, n=2 * N - 1, axis=1)
yf = jnp.fft.fft(y, n=2 * N - 1, axis=1)
hy = jnp.fft.ifft(hf * yf, n=2 * N - 1, axis=1)[
:, (N - 1) // 2 : -(N - 1) // 2
].real.astype(jnp.float32)

x = (jnp.pi * self.dx[0] * self.dx[1] / y.shape[0]) * self.fbp_mask * self.back_project(hy) # type: ignore
return x

@staticmethod
def _ramp_filter(x: ArrayLike, tau: float) -> snp.Array:
"""Compute coefficients of ramp filter used in FBP.

Compute coefficients of ramp filter used in FBP, as defined in
(61) in :cite:`kak-1988-principles`.

Args:
x: Sampling locations at which to compute filter coefficients.
tau: Sampling rate.

Returns:
Spatial-domain coefficients of ramp filter.
"""
# The (x == 0) term in x**2 * np.pi**2 * tau**2 + (x == 0)
# is included to avoid division by zero warnings when x == 1
# since np.where evaluates all values for both True and False
# branches.
return jnp.where(
x == 0,
1.0 / (4.0 * tau**2),
jnp.where(x % 2, -1.0 / (x**2 * np.pi**2 * tau**2 + (x == 0)), 0),
)

@staticmethod
@partial(jax.jit, static_argnames=["ny"])
def _project(
Expand Down Expand Up @@ -179,7 +247,7 @@ def _project(
@partial(jax.jit, static_argnames=["nx"])
def _back_project(
y: ArrayLike, x0: ArrayLike, dx: ArrayLike, nx: Shape, y0: float, angles: ArrayLike
) -> ArrayLike:
) -> snp.Array:
r"""Compute X-ray back projection.

Args:
Expand Down Expand Up @@ -361,7 +429,7 @@ def _project_single(
return proj

@staticmethod
def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> ArrayLike:
def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> snp.Array:
r"""
Args:
proj: Input (set of) projection(s).
Expand Down
32 changes: 32 additions & 0 deletions scico/test/linop/xray/test_xray_2d.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import numpy as np

import jax
import jax.numpy as jnp

import pytest

import scico
import scico.linop
from scico.linop.xray import XRayTransform2D
from scico.metric import psnr


@pytest.mark.filterwarnings("error")
Expand Down Expand Up @@ -81,3 +83,33 @@ def test_matched_adjoint():
angles = np.linspace(0, np.pi, n_projection, endpoint=False)
A = XRayTransform2D((N, N), angles, det_count=det_count, dx=dx)
assert scico.linop.valid_adjoint(A, A.T, eps=1e-5)


@pytest.mark.parametrize("dx", [0.5, 1.0 / np.sqrt(2)])
@pytest.mark.parametrize("det_count_factor", [1.02 / np.sqrt(2.0), 1.0])
def test_fbp(dx, det_count_factor):
N = 256
x_gt = np.zeros((N, N), dtype=np.float32)
N4 = N // 4
x_gt[N4:-N4, N4:-N4] = 1.0

det_count = int(det_count_factor * N)
n_proj = 360
angles = np.linspace(0, np.pi, n_proj, endpoint=False)
A = XRayTransform2D(x_gt.shape, angles, det_count=det_count, dx=dx)
y = A(x_gt)
x_fbp = A.fbp(y)
assert psnr(x_gt, x_fbp) > 28


def test_fbp_jit():
N = 64
x_gt = np.ones((N, N), dtype=np.float32)

det_count = N
n_proj = 90
angles = np.linspace(0, np.pi, n_proj, endpoint=False)
A = XRayTransform2D(x_gt.shape, angles, det_count=det_count)
y = A(x_gt)
fbp = jax.jit(A.fbp)
x_fbp = fbp(y)
Loading