From d3193ea478f75ff8eb4e41df76aa0416a63107ea Mon Sep 17 00:00:00 2001 From: Mike McCann <57153404+Michael-T-McCann@users.noreply.github.com> Date: Tue, 15 Oct 2024 11:44:52 -0600 Subject: [PATCH] Fix XRayTransform boundary conditions (#561) * Add failing tests * Fix boundary handling --- scico/linop/xray/_xray.py | 58 ++++++++++++---- .../xray/{test_xray.py => test_xray_2d.py} | 54 ++++----------- scico/test/linop/xray/test_xray_3d.py | 66 +++++++++++++++++++ 3 files changed, 123 insertions(+), 55 deletions(-) rename scico/test/linop/xray/{test_xray.py => test_xray_2d.py} (54%) create mode 100644 scico/test/linop/xray/test_xray_3d.py diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index 4fe40893..f3067e9a 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -158,20 +158,20 @@ def _project( """ nx = im.shape inds, weights = XRayTransform2D._calc_weights(x0, dx, nx, angles, y0) - # Handle out of bounds indices. In the .at call, inds >= y0 are - # ignored, while inds < 0 wrap around. So we set inds < 0 to ny. - inds = jnp.where(inds >= 0, inds, ny) # avoid incompatible types in the .add (scatter operation) weights = weights.astype(im.dtype) + # Handle out of bounds indices by setting weight to zero + weights_valid = jnp.where((inds >= 0) * (inds < ny), weights, 0.0) y = ( jnp.zeros((len(angles), ny), dtype=im.dtype) .at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds] - .add(im * weights) + .add(im * weights_valid) ) - y = y.at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1].add(im * (1 - weights)) + weights_valid = jnp.where((inds + 1 >= 0) * (inds + 1 < ny), 1 - weights, 0.0) + y = y.at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1].add(im * weights_valid) return y @@ -194,14 +194,15 @@ def _back_project( """ ny = y.shape[1] inds, weights = XRayTransform2D._calc_weights(x0, dx, nx, angles, y0) - # Handle out of bounds indices. In the .at call, inds >= y0 are - # ignored, while inds < 0 wrap around. So we set inds < 0 to ny. - inds = jnp.where(inds >= 0, inds, ny) + # Handle out of bounds indices by setting weight to zero + weights_valid = jnp.where((inds >= 0) * (inds < ny), weights, 0.0) # the idea: [y[0, inds[0]], y[1, inds[1]], ...] - HTy = jnp.sum(y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds] * weights, axis=0) + HTy = jnp.sum(y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds] * weights_valid, axis=0) + + weights_valid = jnp.where((inds + 1 >= 0) * (inds + 1 < ny), 1 - weights, 0.0) HTy = HTy + jnp.sum( - y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1] * (1 - weights), axis=0 + y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1] * weights_valid, axis=0 ) return HTy @@ -401,7 +402,7 @@ def _back_project_single( @staticmethod def _calc_weights( - input_shape: Shape, matrix: snp.Array, output_shape: Shape, slice_offset: int = 0 + input_shape: Shape, matrix: snp.Array, det_shape: Shape, slice_offset: int = 0 ) -> snp.Array: # pixel (0, 0, 0) has its center at (0.5, 0.5, 0.5) x = jnp.mgrid[: input_shape[0], : input_shape[1], : input_shape[2]] + 0.5 # (3, ...) @@ -419,13 +420,46 @@ def _calc_weights( left_edge = Px - w / 2 to_next = jnp.minimum(jnp.ceil(left_edge) - left_edge, w) ul_ind = jnp.floor(left_edge).astype("int32") - ul_ind = jnp.where(ul_ind < 0, max(output_shape), ul_ind) # otherwise negative values wrap ul_weight = to_next[0] * to_next[1] * (1 / w**2) ur_weight = (w - to_next[0]) * to_next[1] * (1 / w**2) ll_weight = to_next[0] * (w - to_next[1]) * (1 / w**2) lr_weight = (w - to_next[0]) * (w - to_next[1]) * (1 / w**2) + # set weights to zero out of bounds + ul_weight = jnp.where( + (ul_ind[0] >= 0) + * (ul_ind[0] < det_shape[0]) + * (ul_ind[1] >= 0) + * (ul_ind[1] < det_shape[1]), + ul_weight, + 0.0, + ) + ur_weight = jnp.where( + (ul_ind[0] + 1 >= 0) + * (ul_ind[0] + 1 < det_shape[0]) + * (ul_ind[1] >= 0) + * (ul_ind[1] < det_shape[1]), + ur_weight, + 0.0, + ) + ll_weight = jnp.where( + (ul_ind[0] >= 0) + * (ul_ind[0] < det_shape[0]) + * (ul_ind[1] + 1 >= 0) + * (ul_ind[1] + 1 < det_shape[1]), + ll_weight, + 0.0, + ) + lr_weight = jnp.where( + (ul_ind[0] + 1 >= 0) + * (ul_ind[0] + 1 < det_shape[0]) + * (ul_ind[1] + 1 >= 0) + * (ul_ind[1] + 1 < det_shape[1]), + lr_weight, + 0.0, + ) + return ul_ind, ul_weight, ur_weight, ll_weight, lr_weight @staticmethod diff --git a/scico/test/linop/xray/test_xray.py b/scico/test/linop/xray/test_xray_2d.py similarity index 54% rename from scico/test/linop/xray/test_xray.py rename to scico/test/linop/xray/test_xray_2d.py index b9e12776..deab4d32 100644 --- a/scico/test/linop/xray/test_xray.py +++ b/scico/test/linop/xray/test_xray_2d.py @@ -5,7 +5,8 @@ import pytest import scico -from scico.linop.xray import XRayTransform2D, XRayTransform3D +import scico.linop +from scico.linop.xray import XRayTransform2D @pytest.mark.filterwarnings("error") @@ -71,45 +72,12 @@ def test_apply_adjoint(): assert y.shape[1] == det_count -def test_3d_scaling(): - x = jnp.zeros((4, 4, 1)) - x = x.at[1:3, 1:3, 0].set(1.0) - - input_shape = x.shape - output_shape = x.shape[:2] - - # default spacing - M = XRayTransform3D.matrices_from_euler_angles(input_shape, output_shape, "X", [0.0]) - H = XRayTransform3D(input_shape, matrices=M, det_shape=output_shape) - - # fmt: off - truth = jnp.array( - [[[0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0]]] - ) # fmt: on - np.testing.assert_allclose(H @ x, truth) - - # bigger voxels in the x (first index) direction - M = XRayTransform3D.matrices_from_euler_angles( - input_shape, output_shape, "X", [0.0], voxel_spacing=[2.0, 1.0, 1.0] - ) - H = XRayTransform3D(input_shape, matrices=M, det_shape=output_shape) - # fmt: off - truth = jnp.array( - [[[0. , 0.5, 0.5, 0. ], - [0. , 0.5, 0.5, 0. ], - [0. , 0.5, 0.5, 0. ], - [0. , 0.5, 0.5, 0. ]]] - ) # fmt: on - np.testing.assert_allclose(H @ x, truth) - - # bigger detector pixels in the x (first index) direction - M = XRayTransform3D.matrices_from_euler_angles( - input_shape, output_shape, "X", [0.0], det_spacing=[2.0, 1.0] - ) - H = XRayTransform3D(input_shape, matrices=M, det_shape=output_shape) - # fmt: off - truth = None # fmt: on # TODO: Check this case more closely. - # np.testing.assert_allclose(H @ x, truth) +def test_matched_adjoint(): + """See https://github.com/lanl/scico/issues/560.""" + N = 16 + det_count = int(N * 1.05 / np.sqrt(2.0)) + dx = 1.0 / np.sqrt(2) + n_projection = 3 + angles = np.linspace(0, np.pi, n_projection, endpoint=False) + A = XRayTransform2D((N, N), angles, det_count=det_count, dx=dx) + assert scico.linop.valid_adjoint(A, A.T, eps=1e-5) diff --git a/scico/test/linop/xray/test_xray_3d.py b/scico/test/linop/xray/test_xray_3d.py new file mode 100644 index 00000000..d96217a4 --- /dev/null +++ b/scico/test/linop/xray/test_xray_3d.py @@ -0,0 +1,66 @@ +import numpy as np + +import jax.numpy as jnp + +import scico.linop +from scico.linop.xray import XRayTransform3D + + +def test_matched_adjoint(): + """See https://github.com/lanl/scico/issues/560.""" + N = 16 + det_count = int(N * 1.05 / np.sqrt(2.0)) + n_projection = 3 + + input_shape = (N, N, N) + det_shape = (det_count, det_count) + + M = XRayTransform3D.matrices_from_euler_angles( + input_shape, det_shape, "X", np.linspace(0, np.pi, n_projection, endpoint=False) + ) + H = XRayTransform3D(input_shape, matrices=M, det_shape=det_shape) + + assert scico.linop.valid_adjoint(H, H.T, eps=1e-5) + + +def test_scaling(): + x = jnp.zeros((4, 4, 1)) + x = x.at[1:3, 1:3, 0].set(1.0) + + input_shape = x.shape + det_shape = x.shape[:2] + + # default spacing + M = XRayTransform3D.matrices_from_euler_angles(input_shape, det_shape, "X", [0.0]) + H = XRayTransform3D(input_shape, matrices=M, det_shape=det_shape) + # fmt: off + truth = jnp.array( + [[[0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0]]] + ) # fmt: on + np.testing.assert_allclose(H @ x, truth) + + # bigger voxels in the x (first index) direction + M = XRayTransform3D.matrices_from_euler_angles( + input_shape, det_shape, "X", [0.0], voxel_spacing=[2.0, 1.0, 1.0] + ) + H = XRayTransform3D(input_shape, matrices=M, det_shape=det_shape) + # fmt: off + truth = jnp.array( + [[[0. , 0.5, 0.5, 0. ], + [0. , 0.5, 0.5, 0. ], + [0. , 0.5, 0.5, 0. ], + [0. , 0.5, 0.5, 0. ]]] + ) # fmt: on + np.testing.assert_allclose(H @ x, truth) + + # bigger detector pixels in the x (first index) direction + M = XRayTransform3D.matrices_from_euler_angles( + input_shape, det_shape, "X", [0.0], det_spacing=[2.0, 1.0] + ) + H = XRayTransform3D(input_shape, matrices=M, det_shape=det_shape) + # fmt: off + truth = None # fmt: on # TODO: Check this case more closely. + # np.testing.assert_allclose(H @ x, truth)