diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index 770bf627..beb55e4e 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -142,17 +142,16 @@ def _project( """ nx = im.shape inds, weights = XRayTransform2D._calc_weights(x0, dx, nx, angles, y0) - # Handle out of bounds indices. In the .at call, inds >= y0 are - # ignored, while inds < 0 wrap around. So we set inds < 0 to ny. - inds = jnp.where(inds >= 0, inds, ny) - + # Handle out of bounds indices by setting weight to zero + weights_valid = jnp.where((inds >= 0) * (inds < ny), weights, 0.0) y = ( jnp.zeros((len(angles), ny)) .at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds] - .add(im * weights) + .add(im * weights_valid) ) - y = y.at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1].add(im * (1 - weights)) + weights_valid = jnp.where((inds + 1 >= 0) * (inds + 1 < ny), 1 - weights, 0.0) + y = y.at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1].add(im * weights_valid) return y @@ -174,14 +173,15 @@ def _back_project( """ ny = y.shape[1] inds, weights = XRayTransform2D._calc_weights(x0, dx, nx, angles, y0) - # Handle out of bounds indices. In the .at call, inds >= y0 are - # ignored, while inds < 0 wrap around. So we set inds < 0 to ny. - inds = jnp.where(inds >= 0, inds, ny) + # Handle out of bounds indices by setting weight to zero + weights_valid = jnp.where((inds >= 0) * (inds < ny), weights, 0.0) # the idea: [y[0, inds[0]], y[1, inds[1]], ...] - HTy = jnp.sum(y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds] * weights, axis=0) + HTy = jnp.sum(y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds] * weights_valid, axis=0) + + weights_valid = jnp.where((inds + 1 >= 0) * (inds + 1 < ny), 1 - weights, 0.0) HTy = HTy + jnp.sum( - y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1] * (1 - weights), axis=0 + y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1] * weights_valid, axis=0 ) return HTy @@ -385,7 +385,7 @@ def _back_project_single( @staticmethod def _calc_weights( - input_shape: Shape, matrix: snp.Array, output_shape: Shape, slice_offset: int = 0 + input_shape: Shape, matrix: snp.Array, det_shape: Shape, slice_offset: int = 0 ) -> snp.Array: # pixel (0, 0, 0) has its center at (0.5, 0.5, 0.5) x = jnp.mgrid[: input_shape[0], : input_shape[1], : input_shape[2]] + 0.5 # (3, ...) @@ -403,13 +403,46 @@ def _calc_weights( left_edge = Px - w / 2 to_next = jnp.minimum(jnp.ceil(left_edge) - left_edge, w) ul_ind = jnp.floor(left_edge).astype("int32") - ul_ind = jnp.where(ul_ind < 0, max(output_shape), ul_ind) # otherwise negative values wrap ul_weight = to_next[0] * to_next[1] * (1 / w**2) ur_weight = (w - to_next[0]) * to_next[1] * (1 / w**2) ll_weight = to_next[0] * (w - to_next[1]) * (1 / w**2) lr_weight = (w - to_next[0]) * (w - to_next[1]) * (1 / w**2) + # set weights to zero out of bounds + ul_weight = jnp.where( + (ul_ind[0] >= 0) + * (ul_ind[0] < det_shape[0]) + * (ul_ind[1] >= 0) + * (ul_ind[1] < det_shape[1]), + ul_weight, + 0.0, + ) + ur_weight = jnp.where( + (ul_ind[0] + 1 >= 0) + * (ul_ind[0] + 1 < det_shape[0]) + * (ul_ind[1] >= 0) + * (ul_ind[1] < det_shape[1]), + ur_weight, + 0.0, + ) + ll_weight = jnp.where( + (ul_ind[0] >= 0) + * (ul_ind[0] < det_shape[0]) + * (ul_ind[1] + 1 >= 0) + * (ul_ind[1] + 1 < det_shape[1]), + ll_weight, + 0.0, + ) + lr_weight = jnp.where( + (ul_ind[0] + 1 >= 0) + * (ul_ind[0] + 1 < det_shape[0]) + * (ul_ind[1] + 1 >= 0) + * (ul_ind[1] + 1 < det_shape[1]), + lr_weight, + 0.0, + ) + return ul_ind, ul_weight, ur_weight, ll_weight, lr_weight @staticmethod