Skip to content

Commit

Permalink
Fix boundary handling
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael-T-McCann committed Oct 15, 2024
1 parent bff2ee3 commit 9212886
Showing 1 changed file with 46 additions and 13 deletions.
59 changes: 46 additions & 13 deletions scico/linop/xray/_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,17 +142,16 @@ def _project(
"""
nx = im.shape
inds, weights = XRayTransform2D._calc_weights(x0, dx, nx, angles, y0)
# Handle out of bounds indices. In the .at call, inds >= y0 are
# ignored, while inds < 0 wrap around. So we set inds < 0 to ny.
inds = jnp.where(inds >= 0, inds, ny)

# Handle out of bounds indices by setting weight to zero
weights_valid = jnp.where((inds >= 0) * (inds < ny), weights, 0.0)
y = (
jnp.zeros((len(angles), ny))
.at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds]
.add(im * weights)
.add(im * weights_valid)
)

y = y.at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1].add(im * (1 - weights))
weights_valid = jnp.where((inds + 1 >= 0) * (inds + 1 < ny), 1 - weights, 0.0)
y = y.at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1].add(im * weights_valid)

return y

Expand All @@ -174,14 +173,15 @@ def _back_project(
"""
ny = y.shape[1]
inds, weights = XRayTransform2D._calc_weights(x0, dx, nx, angles, y0)
# Handle out of bounds indices. In the .at call, inds >= y0 are
# ignored, while inds < 0 wrap around. So we set inds < 0 to ny.
inds = jnp.where(inds >= 0, inds, ny)
# Handle out of bounds indices by setting weight to zero
weights_valid = jnp.where((inds >= 0) * (inds < ny), weights, 0.0)

# the idea: [y[0, inds[0]], y[1, inds[1]], ...]
HTy = jnp.sum(y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds] * weights, axis=0)
HTy = jnp.sum(y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds] * weights_valid, axis=0)

weights_valid = jnp.where((inds + 1 >= 0) * (inds + 1 < ny), 1 - weights, 0.0)
HTy = HTy + jnp.sum(
y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1] * (1 - weights), axis=0
y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1] * weights_valid, axis=0
)

return HTy
Expand Down Expand Up @@ -385,7 +385,7 @@ def _back_project_single(

@staticmethod
def _calc_weights(
input_shape: Shape, matrix: snp.Array, output_shape: Shape, slice_offset: int = 0
input_shape: Shape, matrix: snp.Array, det_shape: Shape, slice_offset: int = 0
) -> snp.Array:
# pixel (0, 0, 0) has its center at (0.5, 0.5, 0.5)
x = jnp.mgrid[: input_shape[0], : input_shape[1], : input_shape[2]] + 0.5 # (3, ...)
Expand All @@ -403,13 +403,46 @@ def _calc_weights(
left_edge = Px - w / 2
to_next = jnp.minimum(jnp.ceil(left_edge) - left_edge, w)
ul_ind = jnp.floor(left_edge).astype("int32")
ul_ind = jnp.where(ul_ind < 0, max(output_shape), ul_ind) # otherwise negative values wrap

ul_weight = to_next[0] * to_next[1] * (1 / w**2)
ur_weight = (w - to_next[0]) * to_next[1] * (1 / w**2)
ll_weight = to_next[0] * (w - to_next[1]) * (1 / w**2)
lr_weight = (w - to_next[0]) * (w - to_next[1]) * (1 / w**2)

# set weights to zero out of bounds
ul_weight = jnp.where(
(ul_ind[0] >= 0)
* (ul_ind[0] < det_shape[0])
* (ul_ind[1] >= 0)
* (ul_ind[1] < det_shape[1]),
ul_weight,
0.0,
)
ur_weight = jnp.where(
(ul_ind[0] + 1 >= 0)
* (ul_ind[0] + 1 < det_shape[0])
* (ul_ind[1] >= 0)
* (ul_ind[1] < det_shape[1]),
ur_weight,
0.0,
)
ll_weight = jnp.where(
(ul_ind[0] >= 0)
* (ul_ind[0] < det_shape[0])
* (ul_ind[1] + 1 >= 0)
* (ul_ind[1] + 1 < det_shape[1]),
ll_weight,
0.0,
)
lr_weight = jnp.where(
(ul_ind[0] + 1 >= 0)
* (ul_ind[0] + 1 < det_shape[0])
* (ul_ind[1] + 1 >= 0)
* (ul_ind[1] + 1 < det_shape[1]),
lr_weight,
0.0,
)

return ul_ind, ul_weight, ur_weight, ll_weight, lr_weight

@staticmethod
Expand Down

0 comments on commit 9212886

Please sign in to comment.