Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a new Functional for TV Norm #456

Merged
merged 30 commits into from
Nov 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
1b7f583
Added a new Functional for TV Norm implementing its proximal operator…
Oct 5, 2023
9d1d73a
added checks for input shape in TV2DNorm
Oct 5, 2023
877df4c
fixed lint errors, changed required argument to default in TV2DNorm, …
Oct 5, 2023
4947382
some unsaved changes from last commit
Oct 5, 2023
4375ddf
some unsaved changes from last commit
Oct 5, 2023
71fc636
newline at end of file error
Oct 5, 2023
c62da28
sort imports lint error
Oct 5, 2023
98ce989
removed the default shape parameter from TV2DNorm
Oct 6, 2023
096d1c9
Some docs edits
bwohlberg Oct 9, 2023
c2e1de5
Disable BlockArray tests on TV2DNorm
bwohlberg Oct 9, 2023
3a0cdb0
Fix black formatting
bwohlberg Oct 9, 2023
605d11b
updated the TV norm logic to apply shrinkage to only the difference o…
Oct 11, 2023
a7e82ba
Merge branch 'main' into tv-norm
bwohlberg Oct 31, 2023
c8efe90
Implementation supporting arbitrary dimensional inputs
bwohlberg Nov 2, 2023
b5e8fc9
Merge branch 'main' into tv-norm
bwohlberg Nov 3, 2023
2654882
Merge branch 'tv-norm' into tv-norm-alt-ver
bwohlberg Nov 3, 2023
ec8686e
Add a test
bwohlberg Nov 3, 2023
4f2f189
Minor changes
bwohlberg Nov 3, 2023
b7427f7
New implementation of TV norm and approximage prox
bwohlberg Nov 3, 2023
c0c9633
Clean up
bwohlberg Nov 3, 2023
f251c60
Typo fix
bwohlberg Nov 3, 2023
feb4b77
Minor change
bwohlberg Nov 4, 2023
7fe98b9
Add change log entry
bwohlberg Nov 4, 2023
2963523
Merge pull request #2 from shnaqvi/tv-norm-alt-ver
shnaqvi Nov 5, 2023
ded11f8
Resolve typing errors
bwohlberg Nov 5, 2023
c760e45
Resolve some oversights and issues arising when 64 bit floats enabled
bwohlberg Nov 5, 2023
dafd626
Standardise code formatting
bwohlberg Nov 5, 2023
2949931
Standardise code formatting
bwohlberg Nov 5, 2023
6a654ec
Standardise code formatting
bwohlberg Nov 5, 2023
3b7f75b
Apply skipped pre-commit
bwohlberg Nov 5, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ SCICO Release Notes
Version 0.0.5 (unreleased)
----------------------------

• New functional ``functional.AnisotropicTVNorm`` with proximal operator
approximation.
• New integrated Radon/X-ray transform ``linop.XRayTransform``.
• Rename modules ``radon_astra`` and ``radon_svmbir`` to ``xray.astra`` and
``xray.svmbir`` respectively, and rename ``TomographicProjector`` classes
Expand Down
12 changes: 12 additions & 0 deletions docs/source/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,18 @@ @Article {jin-2017-unet
doi = {10.1109/TIP.2017.2713099}
}

@Article {kamilov-2016-parallel,
title = {A parallel proximal algorithm for anisotropic total
variation minimization},
author = {Ulugbek S. Kamilov},
journal = {IEEE Transactions on Image Processing},
volume = 26,
number = 2,
pages = {539--548},
year = 2016,
doi = {10.1109/tip.2016.2629449 }
}

@Article {kamilov-2017-plugandplay,
author = {Ulugbek Kamilov and Hassan Mansour and Brendt
Wohlberg},
Expand Down
2 changes: 2 additions & 0 deletions scico/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@
NuclearNorm,
L1MinusL2Norm,
)
from ._tvnorm import AnisotropicTVNorm
from ._indicator import NonNegativeIndicator, L2BallIndicator
from ._denoiser import BM3D, BM4D, DnCNN
from ._dist import SetDistance, SquaredSetDistance


__all__ = [
"AnisotropicTVNorm",
"Functional",
"ScaledFunctional",
"SeparableFunctional",
Expand Down
142 changes: 142 additions & 0 deletions scico/functional/_tvnorm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Anisotropic total variation norm."""

from typing import Optional, Tuple

from scico import numpy as snp
from scico.linop import (
CircularConvolve,
FiniteDifference,
LinearOperator,
VerticalStack,
)
from scico.numpy import Array

from ._functional import Functional
from ._norm import L1Norm


class AnisotropicTVNorm(Functional):
r"""The anisotropic total variation (TV) norm.

The anisotropic total variation (TV) norm computed by

.. code-block:: python

ATV = scico.functional.AnisotropicTVNorm()
x_norm = ATV(x)

is equivalent to

.. code-block:: python

C = linop.FiniteDifference(input_shape=x.shape, circular=True)
L1 = functional.L1Norm()
x_norm = L1(C @ x)

The scaled proximal operator is computed using an approximation that
holds for small scaling parameters :cite:`kamilov-2016-parallel`.
This does not imply that it can only be applied to problems requiring
a small regularization parameter since most proximal algorithms
include an additional algorithm parameter that also plays a role in
the parameter of the proximal operator. For example, in :class:`.PGM`
and :class:`.AcceleratedPGM`, the scaled proximal operator parameter
is the regularization parameter divided by the `L0` algorithm
parameter, and for :class:`.ADMM`, the scaled proximal operator
parameters are the regularization parameters divided by the entries
in the `rho_list` algorithm parameter.
"""

has_eval = True
has_prox = True

def __init__(self, ndims: Optional[int] = None):
r"""
Args:
ndims: Number of (trailing) dimensions of the input over
which to apply the finite difference operator. If
``None``, differences are evaluated along all axes.
"""
self.ndims = ndims
self.h0 = snp.array([1.0, 1.0]) / snp.sqrt(2.0) # lowpass filter
self.h1 = snp.array([1.0, -1.0]) / snp.sqrt(2.0) # highpass filter
self.l1norm = L1Norm()
self.G: Optional[LinearOperator] = None
self.W: Optional[LinearOperator] = None

def __call__(self, x: Array) -> float:
r"""Compute the anisotropic TV norm of an array."""
if self.G is None or self.G.shape[1] != x.shape:
if self.ndims is None:
ndims = x.ndim
else:
ndims = self.ndims
axes = tuple(range(ndims))
self.G = FiniteDifference(
x.shape, input_dtype=x.dtype, axes=axes, circular=True, jit=True
)
return snp.sum(snp.abs(self.G @ x))

@staticmethod
def _shape(idx: int, ndims: int) -> Tuple:
"""Construct a shape tuple.

Construct a tuple of size `ndims` with all unit entries except
for index `idx`, which has a -1 entry.
"""
return (1,) * idx + (-1,) + (1,) * (ndims - idx - 1)

def prox(self, v: Array, lam: float = 1.0, **kwargs) -> Array:
r"""Approximate proximal operator of the isotropic TV norm.

Approximation of the proximal operator of the anisotropic TV norm,
computed via the method described in :cite:`kamilov-2016-parallel`.

Args:
v: Input array :math:`\mb{v}`.
lam: Proximal parameter :math:`\lam`.
kwargs: Additional arguments that may be used by derived
classes.
"""
if self.ndims is None:
ndims = v.ndim
else:
ndims = self.ndims
K = 2 * ndims

if self.W is None or self.W.shape[1] != v.shape:
h0 = self.h0.astype(v.dtype)
h1 = self.h1.astype(v.dtype)
C0 = VerticalStack( # Stack of lowpass filter operators for each axis
[
CircularConvolve(
h0.reshape(AnisotropicTVNorm._shape(k, ndims)),
v.shape,
ndims=self.ndims,
)
for k in range(ndims)
]
)
C1 = VerticalStack( # Stack of highpass filter operators for each axis
[
CircularConvolve(
h1.reshape(AnisotropicTVNorm._shape(k, ndims)),
v.shape,
ndims=self.ndims,
)
for k in range(ndims)
]
)
# single-level shift-invariant Haar transform
self.W = VerticalStack([C0, C1], jit=True)

Wv = self.W @ v
# Apply 𝑙1 shrinkage to highpass component of shift-invariant Haar transform
Wv = Wv.at[1].set(self.l1norm.prox(Wv[1], snp.sqrt(2) * K * lam))
return (1.0 / K) * self.W.T @ Wv
7 changes: 6 additions & 1 deletion scico/test/functional/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
from scico import functional
from scico.random import randn

NO_BLOCK_ARRAY = [functional.L21Norm, functional.L1MinusL2Norm, functional.NuclearNorm]
NO_BLOCK_ARRAY = [
functional.L21Norm,
functional.L1MinusL2Norm,
functional.NuclearNorm,
functional.AnisotropicTVNorm,
]
NO_COMPLEX = [functional.NonNegativeIndicator]


Expand Down
40 changes: 40 additions & 0 deletions scico/test/functional/test_tvnorm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import numpy as np

import scico.random
from scico import functional, linop, loss, metric
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.optimize.pgm import AcceleratedPGM


def test_tvnorm():

N = 128
g = np.linspace(0, 2 * np.pi, N, dtype=np.float32)
x_gt = np.sin(2 * g)
x_gt[x_gt > 0.5] = 0.5
x_gt[x_gt < -0.5] = -0.5
σ = 0.02
noise, key = scico.random.randn(x_gt.shape, seed=0)
y = x_gt + σ * noise

λ = 5e-2
f = loss.SquaredL2Loss(y=y)

C = linop.FiniteDifference(input_shape=x_gt.shape, circular=True)
g = λ * functional.L1Norm()
solver = ADMM(
f=f,
g_list=[g],
C_list=[C],
rho_list=[1e1],
x0=y,
maxiter=50,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 20}),
)
x_tvdn = solver.solve()

h = λ * functional.AnisotropicTVNorm()
solver = AcceleratedPGM(f=f, g=h, L0=2e2, x0=y, maxiter=50)
x_approx = solver.solve()

assert metric.snr(x_tvdn, x_approx) > 45
Loading