From 5c83511bd9cbadffd266b88c1172066592346fcf Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 8 Dec 2023 14:22:58 -0700 Subject: [PATCH] Add `Operator` versions of `DiagonalStack` and `VerticalStack` (#477) * Minor docs edits * Improve docs * Rename parameters * Update change summary * Fix erroneous find-and-replace * Minor improvements * Bug fix * Add vertical and diagonal stack operators * Fix markup * Update change summary * Improve docs * Fix __all__ * Resolve some docs issues * Improve docstrings * LinearOperator stacks derived from Operator stacks * Resolve circular imports * Resolve type checking issues * Minor docstring edit * Docstring edits * Rename test files to avoid name collision * Fix comment * Docs consistency * Address PR review comments * Ensure all list entries are linear operators * Use correct exception type * Change exception type in test * Add tests * Add input checks and tests * Address codefactor complaint --- CHANGES.rst | 13 +- docs/source/include/functional.rst | 4 +- scico/linop/_stack.py | 242 +++++---------- scico/operator/__init__.py | 5 +- scico/operator/_stack.py | 284 ++++++++++++++++++ .../{test_stack.py => test_linop_stack.py} | 46 +-- scico/test/operator/test_op_stack.py | 157 ++++++++++ scico/test/{ => operator}/test_operator.py | 0 8 files changed, 559 insertions(+), 192 deletions(-) create mode 100644 scico/operator/_stack.py rename scico/test/linop/{test_stack.py => test_linop_stack.py} (76%) create mode 100644 scico/test/operator/test_op_stack.py rename scico/test/{ => operator}/test_operator.py (100%) diff --git a/CHANGES.rst b/CHANGES.rst index 7792c36b7..2d68d754a 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -9,11 +9,14 @@ Version 0.0.5 (unreleased) • New functionals ``functional.AnisotropicTVNorm`` and ``functional.ProximalAverage`` with proximal operator approximations. • New integrated Radon/X-ray transform ``linop.XRayTransform``. +• New operators ``operator.DiagonalStack`` and ``operator.VerticalStack``. • Rename modules ``radon_astra`` and ``radon_svmbir`` to ``xray.astra`` and ``xray.svmbir`` respectively, and rename ``TomographicProjector`` classes to ``XRayTransform``. • Rename ``AbelProjector`` to ``AbelTransform``. • Rename ``solver.ATADSolver`` to ``solver.MatrixATADSolver``. +• Rename some ``__init__`` parameters of ``linop.DiagonalStack`` and + ``linop.VerticalStack``. • Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.20. • Support ``flax`` versions up to 0.7.5. • Use ``orbax`` for checkpointing ``flax`` models. @@ -23,10 +26,10 @@ Version 0.0.5 (unreleased) Version 0.0.4 (2023-08-03) ---------------------------- -• Add new `Function` class for representing array-to-array mappings with more +• Add new ``Function`` class for representing array-to-array mappings with more than one input. • Add new methods and a function for computing Jacobian-vector products for - `Operator` objects. + ``Operator`` objects. • Add new proximal ADMM solvers. • Add new ADMM subproblem solvers for problems involving a sum-of-convolutions operator. @@ -35,7 +38,7 @@ Version 0.0.4 (2023-08-03) • Enable diagnostics for ML training loops. • Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.14. • Change required packages and version numbers, including more recent version - for `flax`. + for ``flax``. • Drop support for Python 3.7. • Add support for 3D tomographic projection with the ASTRA Toolbox. 
@@ -45,8 +48,8 @@ Version 0.0.3 (2022-09-21) ---------------------------- • Change required packages and version numbers, including more recent version - requirements for `numpy`, `scipy`, `svmbir`, and `ray`. -• Package `bm4d` removed from main requirements list due to issue #342. + requirements for ``numpy``, ``scipy``, ``svmbir``, and ``ray``. +• Package ``bm4d`` removed from main requirements list due to issue #342. • Support ``jaxlib`` versions 0.3.0 to 0.3.15 and ``jax`` versions 0.3.0 to 0.3.17. • Rename linear operators in ``radon_astra`` and ``radon_svmbir`` modules diff --git a/docs/source/include/functional.rst b/docs/source/include/functional.rst index 7af5f34f5..462e74b54 100644 --- a/docs/source/include/functional.rst +++ b/docs/source/include/functional.rst @@ -91,11 +91,11 @@ in terms of the proximal operators of the :math:`f_i` .. math:: \mathrm{prox}_f(\mb{x}, \lambda) = - \begin{bmatrix} + \begin{pmatrix} \mathrm{prox}_{f_1}(\mb{x}_1, \lambda) \\ \vdots \\ \mathrm{prox}_{f_N}(\mb{x}_N, \lambda) \\ - \end{bmatrix} \;. + \end{pmatrix} \;. Separable Functionals are implemented in the :class:`.SeparableFunctional` class. Separable functionals naturally accept :class:`.BlockArray` inputs and return the prox as a :class:`.BlockArray`. diff --git a/scico/linop/_stack.py b/scico/linop/_stack.py index 878ebd247..8d8ef68ea 100644 --- a/scico/linop/_stack.py +++ b/scico/linop/_stack.py @@ -5,196 +5,120 @@ # user license can be found in the 'LICENSE' file distributed with the # package. -"""Stack of linear operators class.""" +"""Stack of linear operators classes.""" from __future__ import annotations import operator from functools import partial -from typing import List, Optional, Sequence, Tuple, Union - -import numpy as np - -from typing_extensions import TypeGuard +from typing import Optional, Sequence, Union import scico.numpy as snp from scico.numpy import Array, BlockArray -from scico.numpy.util import is_nested -from scico.typing import BlockShape, Shape - -from ._linop import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar - - -def collapse_shapes( - shapes: Sequence[Union[Shape, BlockShape]], allow_collapse=True -) -> Tuple[Union[Shape, BlockShape], bool]: - """Decides whether to collapse a sequence of shapes and returns the collapsed - shape and a boolean indicating whether the shape was collapsed.""" - - if is_collapsible(shapes) and allow_collapse: - return (len(shapes), *shapes[0]), True +from scico.operator._stack import DiagonalStack as DStack +from scico.operator._stack import VerticalStack as VStack - if is_blockable(shapes): - return shapes, False +from ._linop import LinearOperator, _wrap_add_sub - raise ValueError( - "Combining these shapes would result in a twice-nested BlockArray, which is not supported." - ) +class VerticalStack(VStack, LinearOperator): + r"""A vertical stack of linear operators. 
-def is_collapsible(shapes: Sequence[Union[Shape, BlockShape]]) -> bool: - """Return ``True`` if the a list of shapes represent arrays that can - be stacked, i.e., they are all the same.""" - return all(s == shapes[0] for s in shapes) + Given linear operators :math:`A_1, A_2, \dots, A_N`, create the + linear operator - -def is_blockable(shapes: Sequence[Union[Shape, BlockShape]]) -> TypeGuard[Union[Shape, BlockShape]]: - """Return ``True`` if the list of shapes represent arrays that can be - combined into a :class:`BlockArray`, i.e., none are nested.""" - return not any(is_nested(s) for s in shapes) - - -class VerticalStack(LinearOperator): - """A vertical stack of LinearOperators.""" + .. math:: + H = + \begin{pmatrix} + A_1 \\ + A_2 \\ + \vdots \\ + A_N \\ + \end{pmatrix} \qquad + \text{such that} \qquad + H \mb{x} + = + \begin{pmatrix} + A_1(\mb{x}) \\ + A_2(\mb{x}) \\ + \vdots \\ + A_N(\mb{x}) \\ + \end{pmatrix} \;. + """ def __init__( self, - ops: List[LinearOperator], - collapse: Optional[bool] = True, + ops: Sequence[LinearOperator], + collapse_output: Optional[bool] = True, jit: bool = True, **kwargs, ): r""" Args: - ops: Operators to stack. - collapse: If ``True`` and the output would be a + ops: Linear operators to stack. + collapse_output: If ``True`` and the output would be a :class:`BlockArray` with shape ((m, n, ...), (m, n, ...), ...), the output is instead a :class:`jax.Array` with shape (S, m, n, ...) where S is the length of `ops`. - Defaults to ``True``. - jit: see `jit` in :class:`LinearOperator`. + jit: See `jit` in :class:`LinearOperator`. """ - VerticalStack.check_if_stackable(ops) + if not all(isinstance(op, LinearOperator) for op in ops): + raise TypeError("All elements of ops must be of type LinearOperator.") - self.ops = ops - self.collapse = collapse - - output_shapes = tuple(op.output_shape for op in ops) - self.collapsible = is_collapsible(output_shapes) - - if self.collapsible and self.collapse: - output_shape = (len(ops),) + output_shapes[0] # collapse to jax array - else: - output_shape = output_shapes - - super().__init__( - input_shape=ops[0].input_shape, - output_shape=output_shape, # type: ignore - input_dtype=ops[0].input_dtype, - output_dtype=ops[0].output_dtype, - jit=jit, - **kwargs, - ) - - @staticmethod - def check_if_stackable(ops: List[LinearOperator]): - """Check that input ops are suitable for stack creation.""" - if not isinstance(ops, (list, tuple)): - raise ValueError("Expected a list of LinearOperator.") - - input_shapes = [op.shape[1] for op in ops] - if not all(input_shapes[0] == s for s in input_shapes): - raise ValueError( - "Expected all LinearOperators to have the same input shapes, " - f"but got {input_shapes}." - ) - - input_dtypes = [op.input_dtype for op in ops] - if not all(input_dtypes[0] == s for s in input_dtypes): - raise ValueError( - "Expected all LinearOperators to have the same input dtype, " - f"but got {input_dtypes}." 
- ) - - if any([is_nested(op.shape[0]) for op in ops]): - raise ValueError("Cannot stack LinearOperators with nested output shapes.") - - output_dtypes = [op.output_dtype for op in ops] - if not np.all(output_dtypes[0] == s for s in output_dtypes): - raise ValueError("Expected all LinearOperators to have the same output dtype.") - - def _eval(self, x: Array) -> Union[Array, BlockArray]: - if self.collapsible and self.collapse: - return snp.stack([op @ x for op in self.ops]) - return BlockArray([op @ x for op in self.ops]) + super().__init__(ops=ops, collapse_output=collapse_output, jit=jit, **kwargs) def _adj(self, y: Union[Array, BlockArray]) -> Array: # type: ignore - return sum([op.adj(y_block) for y_block, op in zip(y, self.ops)]) - - def scale_ops(self, scalars: Array): - """Scale component linear operators. - - Return a copy of `self` with each operator scaled by the - corresponding entry in `scalars`. - - Args: - scalars: List or array of scalars to use. - """ - if len(scalars) != len(self.ops): - raise ValueError("Expected `scalars` to be the same length as self.ops.") - - return VerticalStack([a * op for a, op in zip(scalars, self.ops)], collapse=self.collapse) + return sum([op.adj(y_block) for y_block, op in zip(y, self.ops)]) # type: ignore @partial(_wrap_add_sub, op=operator.add) def __add__(self, other): # add another VerticalStack of the same shape return VerticalStack( - [op1 + op2 for op1, op2 in zip(self.ops, other.ops)], collapse=self.collapse + [op1 + op2 for op1, op2 in zip(self.ops, other.ops)], + collapse_output=self.collapse_output, ) @partial(_wrap_add_sub, op=operator.sub) def __sub__(self, other): # subtract another VerticalStack of the same shape return VerticalStack( - [op1 - op2 for op1, op2 in zip(self.ops, other.ops)], collapse=self.collapse + [op1 - op2 for op1, op2 in zip(self.ops, other.ops)], + collapse_output=self.collapse_output, ) - @_wrap_mul_div_scalar - def __mul__(self, scalar): - return VerticalStack([scalar * op for op in self.ops], collapse=self.collapse) - - @_wrap_mul_div_scalar - def __rmul__(self, scalar): - return VerticalStack([scalar * op for op in self.ops], collapse=self.collapse) - - @_wrap_mul_div_scalar - def __truediv__(self, scalar): - return VerticalStack([op / scalar for op in self.ops], collapse=self.collapse) +class DiagonalStack(DStack, LinearOperator): + r"""A diagonal stack of linear operators. -class DiagonalStack(LinearOperator): - r"""A diagonal stack of LinearOperators. - - Given operators :math:`A_1, A_2, \dots, A_N`, creates the operator - :math:`H` such that + Given linear operators :math:`A_1, A_2, \dots, A_N`, create the + linear operator .. math:: + H = + \begin{pmatrix} + A_1 & 0 & \ldots & 0\\ + 0 & A_2 & \ldots & 0\\ + \vdots & \vdots & \ddots & \vdots\\ + 0 & 0 & \ldots & A_N \\ + \end{pmatrix} \qquad + \text{such that} \qquad + H \begin{pmatrix} - A_1(\mathbf{x}_1) \\ - A_2(\mathbf{x}_2) \\ + \mb{x}_1 \\ + \mb{x}_2 \\ \vdots \\ - A_N(\mathbf{x}_N) \\ + \mb{x}_N \\ \end{pmatrix} - = H + = \begin{pmatrix} - \mathbf{x}_1 \\ - \mathbf{x}_2 \\ + A_1(\mb{x}_1) \\ + A_2(\mb{x}_2) \\ \vdots \\ - \mathbf{x}_N \\ + A_N(\mb{x}_N) \\ \end{pmatrix} \;. - By default, if the inputs :math:`\mathbf{x}_1, \mathbf{x}_2, \dots, - \mathbf{x}_N` all have the same (possibly nested) shape, `S`, this + By default, if the inputs :math:`\mb{x}_1, \mb{x}_2, \dots, + \mb{x}_N` all have the same (possibly nested) shape, `S`, this operator will work on the stack, i.e., have an input shape of `(N, *S)`. 
If the inputs have distinct shapes, `S1`, `S2`, ..., `SN`, this operator will work on the block concatenation, i.e., @@ -204,51 +128,35 @@ class DiagonalStack(LinearOperator): def __init__( self, - ops: List[LinearOperator], - allow_input_collapse: Optional[bool] = True, - allow_output_collapse: Optional[bool] = True, + ops: Sequence[LinearOperator], + collapse_input: Optional[bool] = True, + collapse_output: Optional[bool] = True, jit: bool = True, **kwargs, ): """ Args: - op: Operators to form into a block matrix. - allow_input_collapse: If ``True``, inputs are expected to be + ops: Operators to stack. + collapse_input: If ``True``, inputs are expected to be stacked along the first dimension when possible. - allow_output_collapse: If ``True``, the output will be + collapse_output: If ``True``, the output will be stacked along the first dimension when possible. - jit: see `jit` in :class:`LinearOperator`. + jit: See `jit` in :class:`LinearOperator`. """ - self.ops = ops - - input_shape, self.collapse_input = collapse_shapes( - tuple(op.input_shape for op in ops), - allow_input_collapse, - ) - - output_shape, self.collapse_output = collapse_shapes( - tuple(op.output_shape for op in ops), - allow_output_collapse, - ) + if not all(isinstance(op, LinearOperator) for op in ops): + raise TypeError("All elements of ops must be of type LinearOperator.") super().__init__( - input_shape=input_shape, - output_shape=output_shape, - input_dtype=ops[0].input_dtype, - output_dtype=ops[0].output_dtype, + ops=ops, + collapse_input=collapse_input, + collapse_output=collapse_output, jit=jit, **kwargs, ) - def _eval(self, x: Union[Array, BlockArray]) -> Union[Array, BlockArray]: - result = tuple(op @ x_n for op, x_n in zip(self.ops, x)) - if self.collapse_output: - return snp.stack(result) - return snp.blockarray(result) - def _adj(self, y: Union[Array, BlockArray]) -> Union[Array, BlockArray]: # type: ignore - result = tuple(op.T @ y_n for op, y_n in zip(self.ops, y)) + result = tuple(op.T @ y_n for op, y_n in zip(self.ops, y)) # type: ignore if self.collapse_input: return snp.stack(result) return snp.blockarray(result) diff --git a/scico/operator/__init__.py b/scico/operator/__init__.py index f9835ef30..fee512369 100644 --- a/scico/operator/__init__.py +++ b/scico/operator/__init__.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2021-2022 by SCICO Developers +# Copyright (C) 2021-2023 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -13,10 +13,13 @@ from ._operator import Operator from .biconvolve import BiConvolve from ._func import operator_from_function, Abs, Angle, Exp +from ._stack import DiagonalStack, VerticalStack __all__ = [ "Operator", "BiConvolve", + "DiagonalStack", + "VerticalStack", "operator_from_function", "Abs", "Angle", diff --git a/scico/operator/_stack.py b/scico/operator/_stack.py new file mode 100644 index 000000000..000baf95c --- /dev/null +++ b/scico/operator/_stack.py @@ -0,0 +1,284 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2023 by SCICO Developers +# All rights reserved. BSD 3-clause License. +# This file is part of the SCICO package. Details of the copyright and +# user license can be found in the 'LICENSE' file distributed with the +# package. 
+
+"""Stack of operators classes."""
+
+from __future__ import annotations
+
+from typing import Optional, Sequence, Tuple, Union
+
+import numpy as np
+
+from typing_extensions import TypeGuard
+
+import scico.numpy as snp
+from scico.numpy import Array, BlockArray
+from scico.numpy.util import is_nested
+from scico.typing import BlockShape, Shape
+
+from ._operator import Operator, _wrap_mul_div_scalar
+
+
+def collapse_shapes(
+    shapes: Sequence[Union[Shape, BlockShape]], allow_collapse=True
+) -> Tuple[Union[Shape, BlockShape], bool]:
+    """Compute the collapsed representation of a sequence of shapes.
+
+    Decide whether to collapse a sequence of shapes, returning either
+    the sequence of shapes or a collapsed shape, and a boolean indicating
+    whether the shape was collapsed."""
+
+    if is_collapsible(shapes) and allow_collapse:
+        return (len(shapes), *shapes[0]), True
+
+    if is_blockable(shapes):
+        return shapes, False
+
+    raise ValueError(
+        "Combining these shapes would result in a twice-nested BlockArray, which is not supported."
+    )
+
+
+def is_collapsible(shapes: Sequence[Union[Shape, BlockShape]]) -> bool:
+    """Determine whether a sequence of shapes can be collapsed.
+
+    Return ``True`` if the list of shapes represents arrays that can
+    be stacked, i.e., they are all the same."""
+    return all(s == shapes[0] for s in shapes)
+
+
+def is_blockable(shapes: Sequence[Union[Shape, BlockShape]]) -> TypeGuard[Union[Shape, BlockShape]]:
+    """Determine whether a sequence of shapes could be a :class:`BlockArray` shape.
+
+    Return ``True`` if the sequence of shapes represents arrays that can
+    be combined into a :class:`BlockArray`, i.e., none are nested."""
+    return not any(is_nested(s) for s in shapes)
+
+
+class VerticalStack(Operator):
+    r"""A vertical stack of operators.
+
+    Given operators :math:`A_1, A_2, \dots, A_N`, create the operator
+    :math:`H` such that
+
+    .. math::
+       H(\mb{x})
+       =
+       \begin{pmatrix}
+            A_1(\mb{x}) \\
+            A_2(\mb{x}) \\
+            \vdots \\
+            A_N(\mb{x}) \\
+       \end{pmatrix} \;.
+    """
+
+    def __init__(
+        self,
+        ops: Sequence[Operator],
+        collapse_output: Optional[bool] = True,
+        jit: bool = True,
+        **kwargs,
+    ):
+        r"""
+        Args:
+            ops: Operators to stack.
+            collapse_output: If ``True`` and the output would be a
+                :class:`BlockArray` with shape ((m, n, ...), (m, n, ...),
+                ...), the output is instead a :class:`jax.Array` with
+                shape (S, m, n, ...) where S is the length of `ops`.
+            jit: See `jit` in :class:`Operator`.
+        """
+        VerticalStack.check_if_stackable(ops)
+
+        self.ops = ops
+        self.collapse_output = collapse_output
+
+        output_shapes = tuple(op.output_shape for op in ops)
+        self.output_collapsible = is_collapsible(output_shapes)
+
+        if self.output_collapsible and self.collapse_output:
+            output_shape = (len(ops),) + output_shapes[0]  # collapse to jax array
+        else:
+            output_shape = output_shapes
+
+        super().__init__(
+            input_shape=ops[0].input_shape,
+            output_shape=output_shape,  # type: ignore
+            input_dtype=ops[0].input_dtype,
+            output_dtype=ops[0].output_dtype,
+            jit=jit,
+            **kwargs,
+        )
+
+    @staticmethod
+    def check_if_stackable(ops: Sequence[Operator]):
+        """Check that input ops are suitable for stack creation."""
+        if not isinstance(ops, (list, tuple)):
+            raise TypeError("Expected a list of Operator.")
+
+        input_shapes = [op.shape[1] for op in ops]
+        if not all(input_shapes[0] == s for s in input_shapes):
+            raise ValueError(
+                "Expected all Operators to have the same input shapes, " f"but got {input_shapes}."
+            )
+
+        input_dtypes = [op.input_dtype for op in ops]
+        if not all(input_dtypes[0] == s for s in input_dtypes):
+            raise ValueError(
+                "Expected all Operators to have the same input dtype, " f"but got {input_dtypes}."
+            )
+
+        if any([is_nested(op.shape[0]) for op in ops]):
+            raise ValueError("Cannot stack Operators with nested output shapes.")
+
+        output_dtypes = [op.output_dtype for op in ops]
+        if not all(output_dtypes[0] == s for s in output_dtypes):
+            raise ValueError("Expected all Operators to have the same output dtype.")
+
+    def _eval(self, x: Array) -> Union[Array, BlockArray]:
+        if self.output_collapsible and self.collapse_output:
+            return snp.stack([op(x) for op in self.ops])
+        return BlockArray([op(x) for op in self.ops])
+
+    def scale_ops(self, scalars: Array):
+        """Scale component operators.
+
+        Return a copy of `self` with each operator scaled by the
+        corresponding entry in `scalars`.
+
+        Args:
+            scalars: List or array of scalars to use.
+        """
+        if len(scalars) != len(self.ops):
+            raise ValueError("Expected scalars to be the same length as self.ops.")
+
+        return self.__class__(
+            [a * op for a, op in zip(scalars, self.ops)], collapse_output=self.collapse_output
+        )
+
+    def __add__(self, other):
+        # add another VerticalStack of the same shape
+        return self.__class__(
+            [op1 + op2 for op1, op2 in zip(self.ops, other.ops)],
+            collapse_output=self.collapse_output,
+        )
+
+    def __sub__(self, other):
+        # subtract another VerticalStack of the same shape
+        return self.__class__(
+            [op1 - op2 for op1, op2 in zip(self.ops, other.ops)],
+            collapse_output=self.collapse_output,
+        )
+
+    @_wrap_mul_div_scalar
+    def __mul__(self, scalar):
+        return self.__class__(
+            [scalar * op for op in self.ops], collapse_output=self.collapse_output
+        )
+
+    @_wrap_mul_div_scalar
+    def __rmul__(self, scalar):
+        return self.__class__(
+            [scalar * op for op in self.ops], collapse_output=self.collapse_output
+        )
+
+    @_wrap_mul_div_scalar
+    def __truediv__(self, scalar):
+        return self.__class__(
+            [op / scalar for op in self.ops], collapse_output=self.collapse_output
+        )
+
+
+class DiagonalStack(Operator):
+    r"""A diagonal stack of operators.
+
+    Given operators :math:`A_1, A_2, \dots, A_N`, create the operator
+    :math:`H` such that
+
+    .. math::
+       H \left(
+       \begin{pmatrix}
+            \mb{x}_1 \\
+            \mb{x}_2 \\
+            \vdots \\
+            \mb{x}_N \\
+       \end{pmatrix} \right)
+       =
+       \begin{pmatrix}
+            A_1(\mb{x}_1) \\
+            A_2(\mb{x}_2) \\
+            \vdots \\
+            A_N(\mb{x}_N) \\
+       \end{pmatrix} \;.
+
+    By default, if the inputs :math:`\mb{x}_1, \mb{x}_2, \dots,
+    \mb{x}_N` all have the same (possibly nested) shape, `S`, this
+    operator will work on the stack, i.e., have an input shape of `(N,
+    *S)`. If the inputs have distinct shapes, `S1`, `S2`, ..., `SN`,
+    this operator will work on the block concatenation, i.e.,
+    have an input shape of `(S1, S2, ..., SN)`. The same holds for the
+    output shape.
+    """
+
+    def __init__(
+        self,
+        ops: Sequence[Operator],
+        collapse_input: Optional[bool] = True,
+        collapse_output: Optional[bool] = True,
+        jit: bool = True,
+        **kwargs,
+    ):
+        """
+        Args:
+            ops: Operators to stack.
+            collapse_input: If ``True``, inputs are expected to be
+                stacked along the first dimension when possible.
+            collapse_output: If ``True``, the output will be
+                stacked along the first dimension when possible.
+            jit: See `jit` in :class:`Operator`.
+
+        """
+        DiagonalStack.check_if_stackable(ops)
+
+        self.ops = ops
+
+        input_shape, self.collapse_input = collapse_shapes(
+            tuple(op.input_shape for op in ops),
+            collapse_input,
+        )
+        output_shape, self.collapse_output = collapse_shapes(
+            tuple(op.output_shape for op in ops),
+            collapse_output,
+        )
+
+        super().__init__(
+            input_shape=input_shape,
+            output_shape=output_shape,
+            input_dtype=ops[0].input_dtype,
+            output_dtype=ops[0].output_dtype,
+            jit=jit,
+            **kwargs,
+        )
+
+    @staticmethod
+    def check_if_stackable(ops: Sequence[Operator]):
+        """Check that input ops are suitable for stack creation."""
+        if not isinstance(ops, (list, tuple)):
+            raise TypeError("Expected a list of Operator.")
+
+        if any([is_nested(op.shape[0]) for op in ops]):
+            raise ValueError("Cannot stack Operators with nested output shapes.")
+
+        output_dtypes = [op.output_dtype for op in ops]
+        if not all(output_dtypes[0] == s for s in output_dtypes):
+            raise ValueError("Expected all Operators to have the same output dtype.")
+
+    def _eval(self, x: Union[Array, BlockArray]) -> Union[Array, BlockArray]:
+        result = tuple(op(x_n) for op, x_n in zip(self.ops, x))
+        if self.collapse_output:
+            return snp.stack(result)
+        return snp.blockarray(result)
diff --git a/scico/test/linop/test_stack.py b/scico/test/linop/test_linop_stack.py
similarity index 76%
rename from scico/test/linop/test_stack.py
rename to scico/test/linop/test_linop_stack.py
index cd59b73ba..f3604b16a 100644
--- a/scico/test/linop/test_stack.py
+++ b/scico/test/linop/test_linop_stack.py
@@ -6,6 +6,7 @@
 import scico.numpy as snp
 
 from scico.linop import Convolve, DiagonalStack, Identity, Sum, VerticalStack
+from scico.operator import Abs
 from scico.test.linop.test_linop import adjoint_test
 
 
@@ -16,9 +17,14 @@ def setup_method(self, method):
     @pytest.mark.parametrize("jit", [False, True])
     def test_construct(self, jit):
         # requires a list of LinearOperators
-        I = Identity((42,))
-        with pytest.raises(ValueError):
-            H = VerticalStack(I, jit=jit)
+        Id = Identity((42,))
+        with pytest.raises(TypeError):
+            H = VerticalStack(Id, jit=jit)
+
+        # requires all list elements to be LinearOperators
+        A = Abs((42,))
+        with pytest.raises(TypeError):
+            H = VerticalStack((A, Id), jit=jit)
 
         # checks input sizes
         A = Identity((3, 2))
@@ -38,7 +44,7 @@ def test_construct(self, jit):
         assert np.allclose(y[0], A @ x)
         assert np.allclose(y[1], B @ x)
 
-        # by default, collapse to jax array when possible
+        # by default, collapse output to jax array when possible
         A = Convolve(snp.ones((2, 2)), (7, 11))
         B = Convolve(snp.ones((2, 2)), (7, 11))
         H = VerticalStack([A, B], jit=jit)
@@ -53,42 +59,42 @@ def test_construct(self, jit):
         # let user turn off collapsing
         A = Convolve(snp.ones((2, 2)), (7, 11))
         B = Convolve(snp.ones((2, 2)), (7, 11))
-        H = VerticalStack([A, B], collapse=False, jit=jit)
+        H = VerticalStack([A, B], collapse_output=False, jit=jit)
         x = np.ones((7, 11))
         y = H @ x
         assert y.shape == ((8, 12), (8, 12))
 
-    @pytest.mark.parametrize("collapse", [False, True])
+    @pytest.mark.parametrize("collapse_output", [False, True])
     @pytest.mark.parametrize("jit", [False, True])
-    def test_adjoint(self, collapse, jit):
+    def test_adjoint(self, collapse_output, jit):
         # general case
         A = Convolve(snp.ones((3, 3)), (7, 11))
         B = Convolve(snp.ones((2, 2)), (7, 11))
-        H = VerticalStack([A, B], collapse=collapse, jit=jit)
+        H = VerticalStack([A, B], collapse_output=collapse_output, jit=jit)
        adjoint_test(H, self.key)
 
        # collapsable case
        A = Convolve(snp.ones((2, 2)), (7, 11))
        B = Convolve(snp.ones((2, 2)), (7,
11)) - H = VerticalStack([A, B], collapse=collapse, jit=jit) + H = VerticalStack([A, B], collapse_output=collapse_output, jit=jit) adjoint_test(H, self.key) - @pytest.mark.parametrize("collapse", [False, True]) + @pytest.mark.parametrize("collapse_output", [False, True]) @pytest.mark.parametrize("jit", [False, True]) - def test_algebra(self, collapse, jit): + def test_algebra(self, collapse_output, jit): # adding A = Convolve(snp.ones((2, 2)), (7, 11)) B = Convolve(snp.ones((2, 2)), (7, 11)) - H = VerticalStack([A, B], collapse=collapse, jit=jit) + H = VerticalStack([A, B], collapse_output=collapse_output, jit=jit) A = Convolve(snp.array(np.random.rand(2, 2)), (7, 11)) B = Convolve(snp.array(np.random.rand(2, 2)), (7, 11)) - G = VerticalStack([A, B], collapse=collapse, jit=jit) + G = VerticalStack([A, B], collapse_output=collapse_output, jit=jit) x = np.ones((7, 11)) S = H + G - # test correctness of adding + # test correctness of addition assert S.output_shape == H.output_shape assert S.input_shape == H.input_shape np.testing.assert_allclose((S @ x)[0], (H @ x + G @ x)[0]) @@ -111,6 +117,12 @@ def test_algebra(self, collapse, jit): class TestBlockDiagonalLinearOperator: + def test_construct(self): + Id = Identity((42,)) + A = Abs((42,)) + with pytest.raises(TypeError): + H = DiagonalStack((A, Id)) + def test_apply(self): S1 = (3, 4) S2 = (3, 5) @@ -124,7 +136,7 @@ def test_apply(self): y = H @ x y_expected = snp.blockarray((snp.ones(S1), 2 * snp.ones(S2), snp.sum(snp.ones(S3)))) - assert y == y_expected + np.testing.assert_equal(y, y_expected) def test_adjoint(self): S1 = (3, 4) @@ -155,7 +167,7 @@ def test_input_collapse(self): H = DiagonalStack((A1, A2)) assert H.input_shape == (2, *S) - H = DiagonalStack((A1, A2), allow_input_collapse=False) + H = DiagonalStack((A1, A2), collapse_input=False) assert H.input_shape == (S, S) def test_output_collapse(self): @@ -167,5 +179,5 @@ def test_output_collapse(self): H = DiagonalStack((A1, A2)) assert H.output_shape == (2, *S1) - H = DiagonalStack((A1, A2), allow_output_collapse=False) + H = DiagonalStack((A1, A2), collapse_output=False) assert H.output_shape == (S1, S1) diff --git a/scico/test/operator/test_op_stack.py b/scico/test/operator/test_op_stack.py new file mode 100644 index 000000000..695bebe95 --- /dev/null +++ b/scico/test/operator/test_op_stack.py @@ -0,0 +1,157 @@ +import numpy as np + +import jax + +import pytest + +import scico.numpy as snp +from scico.operator import Abs, DiagonalStack, Operator, VerticalStack + +TestOpA = Operator(input_shape=(3, 4), output_shape=(2, 3, 4), eval_fn=lambda x: snp.stack((x, x))) +TestOpB = Operator( + input_shape=(3, 4), output_shape=(6, 4), eval_fn=lambda x: snp.concatenate((x, x)) +) +TestOpC = Operator( + input_shape=(3, 4), output_shape=(6, 4), eval_fn=lambda x: snp.concatenate((x, 2 * x)) +) + + +class TestVerticalStack: + def setup_method(self, method): + self.key = jax.random.PRNGKey(12345) + + @pytest.mark.parametrize("jit", [False, True]) + def test_construct(self, jit): + # requires a list of Operators + A = Abs((42,)) + with pytest.raises(TypeError): + H = VerticalStack(A, jit=jit) + + # checks input sizes + A = Abs((3, 2)) + B = Abs((7, 2)) + with pytest.raises(ValueError): + H = VerticalStack([A, B], jit=jit) + + # in general, returns a BlockArray + A = TestOpA + B = TestOpB + H = VerticalStack([A, B], jit=jit) + x = np.ones((3, 4)) + y = H(x) + assert y.shape == ((2, 3, 4), (6, 4)) + + # ... 
result should be [A@x, B@x]
+        assert np.allclose(y[0], A(x))
+        assert np.allclose(y[1], B(x))
+
+        # by default, collapse output to jax array when possible
+        A = TestOpB
+        B = TestOpB
+        H = VerticalStack([A, B], jit=jit)
+        x = np.ones((3, 4))
+        y = H(x)
+        assert y.shape == (2, 6, 4)
+
+        # ... result should be [A@x, B@x]
+        assert np.allclose(y[0], A(x))
+        assert np.allclose(y[1], B(x))
+
+        # let user turn off collapsing
+        A = TestOpA
+        B = TestOpA
+        H = VerticalStack([A, B], collapse_output=False, jit=jit)
+        x = np.ones((3, 4))
+        y = H(x)
+        assert y.shape == ((2, 3, 4), (2, 3, 4))
+
+    @pytest.mark.parametrize("collapse_output", [False, True])
+    @pytest.mark.parametrize("jit", [False, True])
+    def test_algebra(self, collapse_output, jit):
+        # adding
+        A = TestOpB
+        B = TestOpB
+        H = VerticalStack([A, B], collapse_output=collapse_output, jit=jit)
+
+        A = TestOpC
+        B = TestOpC
+        G = VerticalStack([A, B], collapse_output=collapse_output, jit=jit)
+
+        x = np.ones((3, 4))
+        S = H + G
+
+        # test correctness of addition
+        assert S.output_shape == H.output_shape
+        assert S.input_shape == H.input_shape
+        np.testing.assert_allclose((S(x))[0], (H(x) + G(x))[0])
+        np.testing.assert_allclose((S(x))[1], (H(x) + G(x))[1])
+
+        # result of adding two conformable stacks should be a stack
+        assert isinstance(S, VerticalStack)
+        assert isinstance(H - G, VerticalStack)
+
+        # scalar multiplication
+        assert isinstance(1.0 * H, VerticalStack)
+
+        # op scaling
+        scalars = [2.0, 3.0]
+        y1 = S(x)
+        S2 = S.scale_ops(scalars)
+        y2 = S2(x)
+
+        np.testing.assert_allclose(scalars[0] * y1[0], y2[0])
+
+
+class TestBlockDiagonalOperator:
+    def test_construct(self):
+        # requires a list of Operators
+        A = Abs((8,))
+        with pytest.raises(TypeError):
+            H = DiagonalStack(A)
+
+        # no nested output shapes
+        A = Abs(((8,), (10,)))
+        with pytest.raises(ValueError):
+            H = DiagonalStack((A, A))
+
+        # output dtypes must be the same
+        A = Abs(input_shape=(8,), input_dtype=snp.float32)
+        B = Abs(input_shape=(8,), input_dtype=snp.int32)
+        with pytest.raises(ValueError):
+            H = DiagonalStack((A, B))
+
+    def test_apply(self):
+        S1 = (3, 4)
+        S2 = (3, 5)
+        S3 = (2, 2)
+        A1 = Abs(S1)
+        A2 = 2 * Abs(S2)
+        A3 = Abs(S3)
+        H = DiagonalStack((A1, A2, A3))
+
+        x = snp.ones((S1, S2, S3))
+        y = H(x)
+        y_expected = snp.blockarray((snp.ones(S1), 2 * snp.ones(S2), snp.ones(S3)))
+
+        np.testing.assert_equal(y, y_expected)
+
+    def test_input_collapse(self):
+        S = (3, 4)
+        A1 = TestOpA
+        A2 = TestOpB
+
+        H = DiagonalStack((A1, A2))
+        assert H.input_shape == (2, *S)
+
+        H = DiagonalStack((A1, A2), collapse_input=False)
+        assert H.input_shape == (S, S)
+
+    def test_output_collapse(self):
+        A1 = TestOpB
+        A2 = TestOpC
+
+        H = DiagonalStack((A1, A2))
+        assert H.output_shape == (2, *A1.output_shape)
+
+        H = DiagonalStack((A1, A2), collapse_output=False)
+        assert H.output_shape == (A1.output_shape, A1.output_shape)
diff --git a/scico/test/test_operator.py b/scico/test/operator/test_operator.py
similarity index 100%
rename from scico/test/test_operator.py
rename to scico/test/operator/test_operator.py
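
A minimal usage sketch of the new ``operator.VerticalStack`` and
``operator.DiagonalStack`` classes introduced by this patch. The ``Abs``
operator and the shapes are illustrative choices borrowed from the tests
above, not code from the patch itself:

    import scico.numpy as snp
    from scico.operator import Abs, DiagonalStack, VerticalStack

    A = Abs((3, 4))      # elementwise absolute value, a non-linear operator
    B = 2 * Abs((3, 4))  # scaled copy with the same input and output shapes

    # VerticalStack applies every operator to the same input. Since both
    # output shapes are (3, 4), the outputs collapse to a (2, 3, 4) array.
    V = VerticalStack([A, B])
    y = V(snp.ones((3, 4)))  # y.shape == (2, 3, 4)

    # DiagonalStack applies each operator to its own block of the input.
    # Equal input shapes collapse to a stacked input of shape (2, 3, 4).
    D = DiagonalStack([A, B])
    z = D(snp.ones((2, 3, 4)))  # z.shape == (2, 3, 4)

Passing ``collapse_output=False`` (or ``collapse_input=False`` for
``DiagonalStack``) yields :class:`BlockArray` outputs (or inputs) instead,
as exercised in the tests above.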