Add Operator versions of DiagonalStack and VerticalStack (#477)
* Minor docs edits

* Improve docs

* Rename parameters

* Update change summary

* Fix erroneous find-and-replace

* Minor improvements

* Bug fix

* Add vertical and diagonal stack operators

* Fix markup

* Update change summary

* Improve docs

* Fix __all__

* Resolve some docs issues

* Improve docstrings

* LinearOperator stacks derived from Operator stacks

* Resolve circular imports

* Resolve type checking issues

* Minor docstring edit

* Docstring edits

* Rename test files to avoid name collision

* Fix comment

* Docs consistency

* Address PR review comments

* Ensure all list entries are linear operators

* Use correct exception type

* Change exception type in test

* Add tests

* Add input checks and tests

* Address codefactor complaint
bwohlberg authored Dec 8, 2023
1 parent e378d7a commit 5c83511
Showing 8 changed files with 559 additions and 192 deletions.
13 changes: 8 additions & 5 deletions CHANGES.rst
@@ -9,11 +9,14 @@ Version 0.0.5 (unreleased)
 • New functionals ``functional.AnisotropicTVNorm`` and
   ``functional.ProximalAverage`` with proximal operator approximations.
 • New integrated Radon/X-ray transform ``linop.XRayTransform``.
+• New operators ``operator.DiagonalStack`` and ``operator.VerticalStack``.
 • Rename modules ``radon_astra`` and ``radon_svmbir`` to ``xray.astra`` and
   ``xray.svmbir`` respectively, and rename ``TomographicProjector`` classes
   to ``XRayTransform``.
 • Rename ``AbelProjector`` to ``AbelTransform``.
 • Rename ``solver.ATADSolver`` to ``solver.MatrixATADSolver``.
+• Rename some ``__init__`` parameters of ``linop.DiagonalStack`` and
+  ``linop.VerticalStack``.
 • Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.20.
 • Support ``flax`` versions up to 0.7.5.
 • Use ``orbax`` for checkpointing ``flax`` models.
@@ -23,10 +26,10 @@ Version 0.0.5 (unreleased)
 Version 0.0.4 (2023-08-03)
 ----------------------------
 
-• Add new `Function` class for representing array-to-array mappings with more
+• Add new ``Function`` class for representing array-to-array mappings with more
   than one input.
 • Add new methods and a function for computing Jacobian-vector products for
-  `Operator` objects.
+  ``Operator`` objects.
 • Add new proximal ADMM solvers.
 • Add new ADMM subproblem solvers for problems involving a sum-of-convolutions
   operator.
@@ -35,7 +38,7 @@ Version 0.0.4 (2023-08-03)
 • Enable diagnostics for ML training loops.
 • Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.14.
 • Change required packages and version numbers, including more recent version
-  for `flax`.
+  for ``flax``.
 • Drop support for Python 3.7.
 • Add support for 3D tomographic projection with the ASTRA Toolbox.
 
@@ -45,8 +48,8 @@ Version 0.0.3 (2022-09-21)
 ----------------------------
 
 • Change required packages and version numbers, including more recent version
-  requirements for `numpy`, `scipy`, `svmbir`, and `ray`.
-• Package `bm4d` removed from main requirements list due to issue #342.
+  requirements for ``numpy``, ``scipy``, ``svmbir``, and ``ray``.
+• Package ``bm4d`` removed from main requirements list due to issue #342.
 • Support ``jaxlib`` versions 0.3.0 to 0.3.15 and ``jax`` versions
   0.3.0 to 0.3.17.
 • Rename linear operators in ``radon_astra`` and ``radon_svmbir`` modules
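
The new ``operator.VerticalStack`` and ``operator.DiagonalStack`` noted above generalize stacking to non-linear operators. A minimal usage sketch, not part of this commit; it assumes the ``eval_fn`` constructor of ``scico.operator.Operator`` and that both stack classes are exported from ``scico.operator``:

    import scico.numpy as snp
    from scico.operator import DiagonalStack, Operator, VerticalStack

    # Two simple non-linear operators with the same input shape.
    sq = Operator(input_shape=(3,), eval_fn=lambda x: x**2)
    ab = Operator(input_shape=(3,), eval_fn=snp.abs)

    # VerticalStack applies every operator to the same input; with the
    # default collapse_output=True the result is a (2, 3) array.
    V = VerticalStack([sq, ab])
    y = V(snp.array([-1.0, 2.0, -3.0]))

    # DiagonalStack applies the n-th operator to the n-th input block;
    # identical input shapes collapse the input to a single (2, 3) array.
    D = DiagonalStack([sq, ab])
    z = D(snp.ones((2, 3)))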
4 changes: 2 additions & 2 deletions docs/source/include/functional.rst
@@ -91,11 +91,11 @@ in terms of the proximal operators of the :math:`f_i`
 
 .. math::
    \mathrm{prox}_f(\mb{x}, \lambda)
    =
-   \begin{bmatrix}
+   \begin{pmatrix}
      \mathrm{prox}_{f_1}(\mb{x}_1, \lambda) \\
      \vdots \\
      \mathrm{prox}_{f_N}(\mb{x}_N, \lambda) \\
-   \end{bmatrix} \;.
+   \end{pmatrix} \;.
 
 Separable Functionals are implemented in the :class:`.SeparableFunctional` class. Separable functionals naturally accept :class:`.BlockArray` inputs and return the prox as a :class:`.BlockArray`.
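
A short sketch of the blockwise prox described by the equation above, assuming scico's ``SeparableFunctional``, ``L1Norm``, and ``SquaredL2Norm`` APIs; not part of this commit:

    import scico.numpy as snp
    from scico.functional import L1Norm, SeparableFunctional, SquaredL2Norm

    # The prox of a separable functional is computed block by block.
    f = SeparableFunctional([L1Norm(), SquaredL2Norm()])
    x = snp.blockarray([snp.array([1.0, -2.0]), snp.array([3.0, 4.0])])
    px = f.prox(x, 0.5)  # BlockArray of prox_{f_1}(x_1, 0.5) and prox_{f_2}(x_2, 0.5)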
242 changes: 75 additions & 167 deletions scico/linop/_stack.py
@@ -5,196 +5,120 @@
 # user license can be found in the 'LICENSE' file distributed with the
 # package.
 
-"""Stack of linear operators class."""
+"""Stack of linear operators classes."""
 
 from __future__ import annotations
 
 import operator
 from functools import partial
-from typing import List, Optional, Sequence, Tuple, Union
-
-import numpy as np
-
-from typing_extensions import TypeGuard
+from typing import Optional, Sequence, Union
 
 import scico.numpy as snp
 from scico.numpy import Array, BlockArray
-from scico.numpy.util import is_nested
-from scico.typing import BlockShape, Shape
-
-from ._linop import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar
-
-
-def collapse_shapes(
-    shapes: Sequence[Union[Shape, BlockShape]], allow_collapse=True
-) -> Tuple[Union[Shape, BlockShape], bool]:
-    """Decides whether to collapse a sequence of shapes and returns the collapsed
-    shape and a boolean indicating whether the shape was collapsed."""
-
-    if is_collapsible(shapes) and allow_collapse:
-        return (len(shapes), *shapes[0]), True
-
-    if is_blockable(shapes):
-        return shapes, False
-
-    raise ValueError(
-        "Combining these shapes would result in a twice-nested BlockArray, which is not supported."
-    )
-
-
-def is_collapsible(shapes: Sequence[Union[Shape, BlockShape]]) -> bool:
-    """Return ``True`` if the a list of shapes represent arrays that can
-    be stacked, i.e., they are all the same."""
-    return all(s == shapes[0] for s in shapes)
-
-
-def is_blockable(shapes: Sequence[Union[Shape, BlockShape]]) -> TypeGuard[Union[Shape, BlockShape]]:
-    """Return ``True`` if the list of shapes represent arrays that can be
-    combined into a :class:`BlockArray`, i.e., none are nested."""
-    return not any(is_nested(s) for s in shapes)
-
-
-class VerticalStack(LinearOperator):
-    """A vertical stack of LinearOperators."""
+from scico.operator._stack import DiagonalStack as DStack
+from scico.operator._stack import VerticalStack as VStack
+
+from ._linop import LinearOperator, _wrap_add_sub
+
+
+class VerticalStack(VStack, LinearOperator):
+    r"""A vertical stack of linear operators.
+
+    Given linear operators :math:`A_1, A_2, \dots, A_N`, create the
+    linear operator
+
+    .. math::
+       H =
+       \begin{pmatrix}
+            A_1 \\
+            A_2 \\
+            \vdots \\
+            A_N \\
+       \end{pmatrix} \qquad
+       \text{such that} \qquad
+       H \mb{x}
+       =
+       \begin{pmatrix}
+            A_1(\mb{x}) \\
+            A_2(\mb{x}) \\
+            \vdots \\
+            A_N(\mb{x}) \\
+       \end{pmatrix} \;.
+    """
 
     def __init__(
         self,
-        ops: List[LinearOperator],
-        collapse: Optional[bool] = True,
+        ops: Sequence[LinearOperator],
+        collapse_output: Optional[bool] = True,
         jit: bool = True,
         **kwargs,
     ):
         r"""
         Args:
-            ops: Operators to stack.
-            collapse: If ``True`` and the output would be a
+            ops: Linear operators to stack.
+            collapse_output: If ``True`` and the output would be a
                 :class:`BlockArray` with shape ((m, n, ...), (m, n, ...),
                 ...), the output is instead a :class:`jax.Array` with
                 shape (S, m, n, ...) where S is the length of `ops`.
                 Defaults to ``True``.
-            jit: see `jit` in :class:`LinearOperator`.
+            jit: See `jit` in :class:`LinearOperator`.
         """
-        VerticalStack.check_if_stackable(ops)
+        if not all(isinstance(op, LinearOperator) for op in ops):
+            raise TypeError("All elements of ops must be of type LinearOperator.")
 
-        self.ops = ops
-        self.collapse = collapse
-
-        output_shapes = tuple(op.output_shape for op in ops)
-        self.collapsible = is_collapsible(output_shapes)
-
-        if self.collapsible and self.collapse:
-            output_shape = (len(ops),) + output_shapes[0]  # collapse to jax array
-        else:
-            output_shape = output_shapes
-
-        super().__init__(
-            input_shape=ops[0].input_shape,
-            output_shape=output_shape,  # type: ignore
-            input_dtype=ops[0].input_dtype,
-            output_dtype=ops[0].output_dtype,
-            jit=jit,
-            **kwargs,
-        )
-
-    @staticmethod
-    def check_if_stackable(ops: List[LinearOperator]):
-        """Check that input ops are suitable for stack creation."""
-        if not isinstance(ops, (list, tuple)):
-            raise ValueError("Expected a list of LinearOperator.")
-
-        input_shapes = [op.shape[1] for op in ops]
-        if not all(input_shapes[0] == s for s in input_shapes):
-            raise ValueError(
-                "Expected all LinearOperators to have the same input shapes, "
-                f"but got {input_shapes}."
-            )
-
-        input_dtypes = [op.input_dtype for op in ops]
-        if not all(input_dtypes[0] == s for s in input_dtypes):
-            raise ValueError(
-                "Expected all LinearOperators to have the same input dtype, "
-                f"but got {input_dtypes}."
-            )
-
-        if any([is_nested(op.shape[0]) for op in ops]):
-            raise ValueError("Cannot stack LinearOperators with nested output shapes.")
-
-        output_dtypes = [op.output_dtype for op in ops]
-        if not np.all(output_dtypes[0] == s for s in output_dtypes):
-            raise ValueError("Expected all LinearOperators to have the same output dtype.")
-
-    def _eval(self, x: Array) -> Union[Array, BlockArray]:
-        if self.collapsible and self.collapse:
-            return snp.stack([op @ x for op in self.ops])
-        return BlockArray([op @ x for op in self.ops])
+        super().__init__(ops=ops, collapse_output=collapse_output, jit=jit, **kwargs)
 
     def _adj(self, y: Union[Array, BlockArray]) -> Array:  # type: ignore
-        return sum([op.adj(y_block) for y_block, op in zip(y, self.ops)])
-
-    def scale_ops(self, scalars: Array):
-        """Scale component linear operators.
-
-        Return a copy of `self` with each operator scaled by the
-        corresponding entry in `scalars`.
-
-        Args:
-            scalars: List or array of scalars to use.
-        """
-        if len(scalars) != len(self.ops):
-            raise ValueError("Expected `scalars` to be the same length as self.ops.")
-
-        return VerticalStack([a * op for a, op in zip(scalars, self.ops)], collapse=self.collapse)
+        return sum([op.adj(y_block) for y_block, op in zip(y, self.ops)])  # type: ignore
 
     @partial(_wrap_add_sub, op=operator.add)
     def __add__(self, other):
         # add another VerticalStack of the same shape
         return VerticalStack(
-            [op1 + op2 for op1, op2 in zip(self.ops, other.ops)], collapse=self.collapse
+            [op1 + op2 for op1, op2 in zip(self.ops, other.ops)],
+            collapse_output=self.collapse_output,
         )
 
     @partial(_wrap_add_sub, op=operator.sub)
     def __sub__(self, other):
         # subtract another VerticalStack of the same shape
         return VerticalStack(
-            [op1 - op2 for op1, op2 in zip(self.ops, other.ops)], collapse=self.collapse
+            [op1 - op2 for op1, op2 in zip(self.ops, other.ops)],
+            collapse_output=self.collapse_output,
         )
 
-    @_wrap_mul_div_scalar
-    def __mul__(self, scalar):
-        return VerticalStack([scalar * op for op in self.ops], collapse=self.collapse)
-
-    @_wrap_mul_div_scalar
-    def __rmul__(self, scalar):
-        return VerticalStack([scalar * op for op in self.ops], collapse=self.collapse)
-
-    @_wrap_mul_div_scalar
-    def __truediv__(self, scalar):
-        return VerticalStack([op / scalar for op in self.ops], collapse=self.collapse)
-
-
-class DiagonalStack(LinearOperator):
-    r"""A diagonal stack of LinearOperators.
-
-    Given operators :math:`A_1, A_2, \dots, A_N`, creates the operator
-    :math:`H` such that
-
-    .. math::
-       \begin{pmatrix}
-            A_1(\mathbf{x}_1) \\
-            A_2(\mathbf{x}_2) \\
-            \vdots \\
-            A_N(\mathbf{x}_N) \\
-       \end{pmatrix}
-       = H
-       \begin{pmatrix}
-            \mathbf{x}_1 \\
-            \mathbf{x}_2 \\
-            \vdots \\
-            \mathbf{x}_N \\
-       \end{pmatrix} \;.
-
-    By default, if the inputs :math:`\mathbf{x}_1, \mathbf{x}_2, \dots,
-    \mathbf{x}_N` all have the same (possibly nested) shape, `S`, this
+
+class DiagonalStack(DStack, LinearOperator):
+    r"""A diagonal stack of linear operators.
+
+    Given linear operators :math:`A_1, A_2, \dots, A_N`, create the
+    linear operator
+
+    .. math::
+       H =
+       \begin{pmatrix}
+            A_1 & 0 & \ldots & 0\\
+            0 & A_2 & \ldots & 0\\
+            \vdots & \vdots & \ddots & \vdots\\
+            0 & 0 & \ldots & A_N \\
+       \end{pmatrix} \qquad
+       \text{such that} \qquad
+       H
+       \begin{pmatrix}
+            \mb{x}_1 \\
+            \mb{x}_2 \\
+            \vdots \\
+            \mb{x}_N \\
+       \end{pmatrix}
+       =
+       \begin{pmatrix}
+            A_1(\mb{x}_1) \\
+            A_2(\mb{x}_2) \\
+            \vdots \\
+            A_N(\mb{x}_N) \\
+       \end{pmatrix} \;.
+
+    By default, if the inputs :math:`\mb{x}_1, \mb{x}_2, \dots,
+    \mb{x}_N` all have the same (possibly nested) shape, `S`, this
     operator will work on the stack, i.e., have an input shape of `(N,
     *S)`. If the inputs have distinct shapes, `S1`, `S2`, ..., `SN`,
     this operator will work on the block concatenation, i.e.,
@@ -204,51 +128,35 @@ class DiagonalStack(LinearOperator):
 
     def __init__(
         self,
-        ops: List[LinearOperator],
-        allow_input_collapse: Optional[bool] = True,
-        allow_output_collapse: Optional[bool] = True,
+        ops: Sequence[LinearOperator],
+        collapse_input: Optional[bool] = True,
+        collapse_output: Optional[bool] = True,
         jit: bool = True,
         **kwargs,
     ):
         """
         Args:
-            op: Operators to form into a block matrix.
-            allow_input_collapse: If ``True``, inputs are expected to be
+            ops: Operators to stack.
+            collapse_input: If ``True``, inputs are expected to be
                 stacked along the first dimension when possible.
-            allow_output_collapse: If ``True``, the output will be
+            collapse_output: If ``True``, the output will be
                 stacked along the first dimension when possible.
-            jit: see `jit` in :class:`LinearOperator`.
+            jit: See `jit` in :class:`LinearOperator`.
         """
-        self.ops = ops
-
-        input_shape, self.collapse_input = collapse_shapes(
-            tuple(op.input_shape for op in ops),
-            allow_input_collapse,
-        )
-
-        output_shape, self.collapse_output = collapse_shapes(
-            tuple(op.output_shape for op in ops),
-            allow_output_collapse,
-        )
+        if not all(isinstance(op, LinearOperator) for op in ops):
+            raise TypeError("All elements of ops must be of type LinearOperator.")
 
         super().__init__(
-            input_shape=input_shape,
-            output_shape=output_shape,
-            input_dtype=ops[0].input_dtype,
-            output_dtype=ops[0].output_dtype,
+            ops=ops,
+            collapse_input=collapse_input,
+            collapse_output=collapse_output,
             jit=jit,
             **kwargs,
        )
 
-    def _eval(self, x: Union[Array, BlockArray]) -> Union[Array, BlockArray]:
-        result = tuple(op @ x_n for op, x_n in zip(self.ops, x))
-        if self.collapse_output:
-            return snp.stack(result)
-        return snp.blockarray(result)
-
     def _adj(self, y: Union[Array, BlockArray]) -> Union[Array, BlockArray]:  # type: ignore
-        result = tuple(op.T @ y_n for op, y_n in zip(self.ops, y))
+        result = tuple(op.T @ y_n for op, y_n in zip(self.ops, y))  # type: ignore
         if self.collapse_input:
            return snp.stack(result)
        return snp.blockarray(result)
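
A minimal usage sketch of the reworked ``linop`` stacks and their adjoints, not part of this commit; the ``Identity`` operators and collapse defaults are assumptions based on the docstrings above:

    import scico.numpy as snp
    from scico.linop import DiagonalStack, Identity, VerticalStack

    A = Identity((3, 4))
    B = 2.0 * Identity((3, 4))  # scalar multiple is still a LinearOperator

    # Same input to every operator; collapsed output has shape (2, 3, 4).
    V = VerticalStack([A, B], collapse_output=True)
    y = V @ snp.ones((3, 4))
    xt = V.T @ y  # adjoint sums per-operator adjoints, back to shape (3, 4)

    # Block-diagonal action: the n-th operator sees the n-th input slice.
    D = DiagonalStack([A, B])
    z = D @ snp.ones((2, 3, 4))  # output shape (2, 3, 4)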
