From 1b7f583e16249d5a3094c74103b60024ab7111d1 Mon Sep 17 00:00:00 2001 From: Salman Naqvi Date: Thu, 5 Oct 2023 15:02:48 -0700 Subject: [PATCH] Added a new Functional for TV Norm implementing its proximal operator using the fast subiteration free algorithm proposed by Kamilov, 2016 --- scico/functional/__init__.py | 2 + scico/functional/_norm.py | 104 +++++++++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+) diff --git a/scico/functional/__init__.py b/scico/functional/__init__.py index 53d426067..377e29b1a 100644 --- a/scico/functional/__init__.py +++ b/scico/functional/__init__.py @@ -20,6 +20,7 @@ L21Norm, NuclearNorm, L1MinusL2Norm, + TV2DNorm, ) from ._indicator import NonNegativeIndicator, L2BallIndicator from ._denoiser import BM3D, BM4D, DnCNN @@ -46,6 +47,7 @@ "BM3D", "BM4D", "DnCNN", + "TV2DNorm", ] # Imported items in __all__ appear to originate in top-level functional module diff --git a/scico/functional/_norm.py b/scico/functional/_norm.py index 332d0500f..22e7e9825 100644 --- a/scico/functional/_norm.py +++ b/scico/functional/_norm.py @@ -15,6 +15,7 @@ from scico.numpy import Array, BlockArray, count_nonzero from scico.numpy.linalg import norm from scico.numpy.util import no_nan_divide +from scico.linop import FiniteDifference from ._functional import Functional @@ -477,3 +478,106 @@ def prox( svdU, svdS, svdV = snp.linalg.svd(v, full_matrices=False) svdS = snp.maximum(0, svdS - lam) return svdU @ snp.diag(svdS) @ svdV + + +class TV2DNorm(Functional): + r"""The :math:`\ell_{TV}` norm. + + For a :math:`M \times N` matrix, :math:`\mb{A}`, by default, + + .. math:: + \norm{\mb{A}}_{TV} = \sum_{n=1}^N \sum_{m=1}^M + \abs{\nabla{A}_{m,n}} \;. + + This norm currently only has proximal operator defined only for + 2 dimensional data. + + For `BlockArray` inputs, the :math:`\ell_{TV}` norm follows the + reduction rules described in :class:`BlockArray`. + + A typical use case is computing the anisotropic total variation norm. + """ + + has_eval = True + has_prox = True + + def __init__(self, dims, tau: float = 1.0): + r""" + Args: + tau: Parameter :math:`\tau` in the norm definition. + """ + self.dims = dims + self.tau = tau + + def __call__(self, x: Union[Array, BlockArray]) -> float: + r"""Return the :math:`\ell_{TV}` norm of an array.""" + y = 0 + gradOp = FiniteDifference(self.dims, input_dtype=x.dtype, circular=True) + grads = gradOp @ x + for g in grads: + y += snp.abs(g) + return self.tau * snp.sum(y) + + def prox( + self, x: Union[Array, BlockArray], lam: float = 1.0, **kwargs + ) -> Union[Array, BlockArray]: + r"""Proximal operator of the :math:`\ell_{TV}` norm. + + Evaluate proximal operator of the TV norm + :cite:`tip-2016-kamilov`. + + Args: + v: Input array :math:`\mb{v}`. + lam: Proximal parameter :math:`\lam`. + kwargs: Additional arguments that may be used by derived + classes. + """ + D = 2 + K = 2*D + thresh = snp.sqrt(2) * K * self.tau * lam + + y = snp.zeros_like(x) + for ax in range(2): + y = y.at[:].add(self.iht2(self.shrink(self.ht2(x, axis=ax, shift=False), thresh), axis=ax, shift=False)) + y = y.at[:].add(self.iht2(self.shrink(self.ht2(x, axis=ax, shift=True), thresh), axis=ax, shift=True)) + y = y.at[:].divide(K) + + return y + + def ht2(self, x, axis, shift): + s = x.shape + w = snp.zeros(s) + C = 1 / snp.sqrt(2) + if shift: + x = snp.roll(x, -1, axis=axis) + + m = s[axis] // 2 + if not axis: + w = w.at[:m, :].set(C * (x[1::2, :] + x[::2, :])) + w = w.at[m:, :].set(C * (x[1::2, :] - x[::2, :])) + else: + w = w.at[:, :m].set(C * (x[:, 1::2] + x[:, ::2])) + w = w.at[:, m:].set(C * (x[:, 1::2] - x[:, ::2])) + return w + + def iht2(self, w, axis, shift): + s = snp.shape(w) + y = snp.zeros(s) + C = 1 / snp.sqrt(2) + m = s[axis] // 2 + if not axis: + y = y.at[::2, :].set(C * (w[:m, :] - w[m:, :])) + y = y.at[1::2, :].set(C * (w[:m, :] + w[m:, :])) + else: + y = y.at[:, ::2].set(C * (w[:, :m] - w[:, m:])) + y = y.at[:, 1::2].set(C * (w[:, :m] + w[:, m:])) + + if shift: + y = snp.roll(y, 1, axis) + + return y + + def shrink(self, x, tau): + threshed = snp.maximum(snp.abs(x)-tau, 0) + threshed = threshed.at[:].multiply(snp.sign(x)) + return threshed \ No newline at end of file