diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index 6d24d397..7486c6e8 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -189,7 +189,9 @@ def _back_project( @staticmethod @partial(jax.jit, static_argnames=["nx"]) @partial(jax.vmap, in_axes=(None, None, None, 0, None)) - def _calc_weights(x0: ArrayLike, dx: ArrayLike, nx: int, angle: float, y0: float) -> snp.Array: + def _calc_weights( + x0: ArrayLike, dx: ArrayLike, nx: Shape, angle: float, y0: float + ) -> snp.Array: """ Args: @@ -249,10 +251,6 @@ class XRayTransform3D(LinearOperator): :meth:`XRayTransform3D.matrices_from_euler_angles` can help to make these geometry arrays. - - - - """ def __init__( @@ -268,7 +266,7 @@ def __init__( det_shape: Shape of detector. """ - self.input_shape = input_shape + self.input_shape: Shape = input_shape self.matrices = matrices self.det_shape = det_shape self.output_shape = (len(matrices), *det_shape)