diff --git a/CHANGELOG.md b/CHANGELOG.md index 075a33c946b..43347fe3efc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `CriticalSuccessIndex` metric to image subpackage ([#2257](https://github.com/Lightning-AI/torchmetrics/pull/2257)) +- Added `Spatial Correlation Coefficient` to image subpackage ([#2248](https://github.com/Lightning-AI/torchmetrics/pull/2248)) + + ### Changed - Changed minimum supported Pytorch version from 1.8 to 1.10 ([#2145](https://github.com/Lightning-AI/torchmetrics/pull/2145)) diff --git a/docs/source/image/spatial_correlation_coefficient.rst b/docs/source/image/spatial_correlation_coefficient.rst new file mode 100644 index 00000000000..02ed96fd107 --- /dev/null +++ b/docs/source/image/spatial_correlation_coefficient.rst @@ -0,0 +1,21 @@ +.. customcarditem:: + :header: Spatial Correlation Coefficient (SCC) + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg + :tags: Image + +.. include:: ../links.rst + +##################################### +Spatial Correlation Coefficient (SCC) +##################################### + +Module Interface +________________ + +.. autoclass:: torchmetrics.image.SpatialCorrelationCoefficient + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.image.spatial_correlation_coefficient diff --git a/docs/source/links.rst b/docs/source/links.rst index eafa1a5ffa7..b2e25dabd08 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -168,3 +168,4 @@ .. _FLORES-101: https://arxiv.org/abs/2106.03193 .. _FLORES-200: https://arxiv.org/abs/2207.04672 .. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html +.. _SCC: https://www.ingentaconnect.com/content/tandf/tres/1998/00000019/00000004/art00013 diff --git a/src/torchmetrics/functional/image/__init__.py b/src/torchmetrics/functional/image/__init__.py index ae5e3ef745f..46008539f45 100644 --- a/src/torchmetrics/functional/image/__init__.py +++ b/src/torchmetrics/functional/image/__init__.py @@ -23,6 +23,7 @@ from torchmetrics.functional.image.rase import relative_average_spectral_error from torchmetrics.functional.image.rmse_sw import root_mean_squared_error_using_sliding_window from torchmetrics.functional.image.sam import spectral_angle_mapper +from torchmetrics.functional.image.scc import spatial_correlation_coefficient from torchmetrics.functional.image.ssim import ( multiscale_structural_similarity_index_measure, structural_similarity_index_measure, @@ -49,4 +50,5 @@ "learned_perceptual_image_patch_similarity", "perceptual_path_length", "critical_success_index", + "spatial_correlation_coefficient", ] diff --git a/src/torchmetrics/functional/image/scc.py b/src/torchmetrics/functional/image/scc.py new file mode 100644 index 00000000000..167ddbd37b5 --- /dev/null +++ b/src/torchmetrics/functional/image/scc.py @@ -0,0 +1,221 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Optional, Tuple, Union + +import torch +from torch import Tensor, tensor +from torch.nn.functional import conv2d, pad +from typing_extensions import Literal + +from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.utilities.distributed import reduce + + +def _scc_update(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: int) -> Tuple[Tensor, Tensor, Tensor]: + """Update and returns variables required to compute Spatial Correlation Coefficient. + + Args: + preds: Predicted tensor + target: Ground truth tensor + hp_filter: High-pass filter tensor + window_size: Local window size integer + + Return: + Tuple of (preds, target, hp_filter) tensors + + Raises: + ValueError: + If ``preds`` and ``target`` have different number of channels + If ``preds`` and ``target`` have different shapes + If ``preds`` and ``target`` have invalid shapes + If ``window_size`` is not a positive integer + If ``window_size`` is greater than the size of the image + + """ + if preds.dtype != target.dtype: + target = target.to(preds.dtype) + _check_same_shape(preds, target) + if preds.ndim not in (3, 4): + raise ValueError( + "Expected `preds` and `target` to have batch of colored images with BxCxHxW shape" + " or batch of grayscale images of BxHxW shape." + f" Got preds: {preds.shape} and target: {target.shape}." + ) + + if len(preds.shape) == 3: + preds = preds.unsqueeze(1) + target = target.unsqueeze(1) + + if not window_size > 0: + raise ValueError(f"Expected `window_size` to be a positive integer. Got {window_size}.") + + if window_size > preds.size(2) or window_size > preds.size(3): + raise ValueError( + f"Expected `window_size` to be less than or equal to the size of the image." + f" Got window_size: {window_size} and image size: {preds.size(2)}x{preds.size(3)}." + ) + + preds = preds.to(torch.float32) + target = target.to(torch.float32) + hp_filter = hp_filter[None, None, :].to(dtype=preds.dtype, device=preds.device) + return preds, target, hp_filter + + +def _symmetric_reflect_pad_2d(input_img: Tensor, pad: Union[int, Tuple[int, ...]]) -> Tensor: + """Applies symmetric padding to the 2D image tensor input using ``reflect`` mode (d c b a | a b c d | d c b a).""" + if isinstance(pad, int): + pad = (pad, pad, pad, pad) + if len(pad) != 4: + raise ValueError(f"Expected padding to have length 4, but got {len(pad)}") + + left_pad = input_img[:, :, :, 0 : pad[0]].flip(dims=[3]) + right_pad = input_img[:, :, :, -pad[1] :].flip(dims=[3]) + padded = torch.cat([left_pad, input_img, right_pad], dim=3) + + top_pad = padded[:, :, 0 : pad[2], :].flip(dims=[2]) + bottom_pad = padded[:, :, -pad[3] :, :].flip(dims=[2]) + return torch.cat([top_pad, padded, bottom_pad], dim=2) + + +def _signal_convolve_2d(input_img: Tensor, kernel: Tensor) -> Tensor: + """Applies 2D signal convolution to the input tensor with the given kernel.""" + left_padding = int(math.floor((kernel.size(3) - 1) / 2)) + right_padding = int(math.ceil((kernel.size(3) - 1) / 2)) + top_padding = int(math.floor((kernel.size(2) - 1) / 2)) + bottom_padding = int(math.ceil((kernel.size(2) - 1) / 2)) + + padded = _symmetric_reflect_pad_2d(input_img, pad=(left_padding, right_padding, top_padding, bottom_padding)) + kernel = kernel.flip([2, 3]) + return conv2d(padded, kernel, stride=1, padding=0) + + +def _hp_2d_laplacian(input_img: Tensor, kernel: Tensor) -> Tensor: + """Applies 2-D Laplace filter to the input tensor with the given high pass filter.""" + return _signal_convolve_2d(input_img, kernel) * 2.0 + + +def _local_variance_covariance(preds: Tensor, target: Tensor, window: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + """Computes local variance and covariance of the input tensors.""" + # This code is inspired by + # https://github.com/andrewekhalel/sewar/blob/master/sewar/full_ref.py#L187. + + left_padding = int(math.ceil((window.size(3) - 1) / 2)) + right_padding = int(math.floor((window.size(3) - 1) / 2)) + + preds = pad(preds, (left_padding, right_padding, left_padding, right_padding)) + target = pad(target, (left_padding, right_padding, left_padding, right_padding)) + + preds_mean = conv2d(preds, window, stride=1, padding=0) + target_mean = conv2d(target, window, stride=1, padding=0) + + preds_var = conv2d(preds**2, window, stride=1, padding=0) - preds_mean**2 + target_var = conv2d(target**2, window, stride=1, padding=0) - target_mean**2 + target_preds_cov = conv2d(target * preds, window, stride=1, padding=0) - target_mean * preds_mean + + return preds_var, target_var, target_preds_cov + + +def _scc_per_channel_compute(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: int) -> Tensor: + """Computes per channel Spatial Correlation Coefficient. + + Args: + preds: estimated image of Bx1xHxW shape. + target: ground truth image of Bx1xHxW shape. + hp_filter: 2D high-pass filter. + window_size: size of window for local mean calculation. + + Return: + Tensor with Spatial Correlation Coefficient score + + """ + dtype = preds.dtype + device = preds.device + + # This code is inspired by + # https://github.com/andrewekhalel/sewar/blob/master/sewar/full_ref.py#L187. + + window = torch.ones(size=(1, 1, window_size, window_size), dtype=dtype, device=device) / (window_size**2) + + preds_hp = _hp_2d_laplacian(preds, hp_filter) + target_hp = _hp_2d_laplacian(target, hp_filter) + + preds_var, target_var, target_preds_cov = _local_variance_covariance(preds_hp, target_hp, window) + + preds_var[preds_var < 0] = 0 + target_var[target_var < 0] = 0 + + den = torch.sqrt(target_var) * torch.sqrt(preds_var) + idx = den == 0 + den[den == 0] = 1 + scc = target_preds_cov / den + scc[idx] = 0 + return scc + + +def spatial_correlation_coefficient( + preds: Tensor, + target: Tensor, + hp_filter: Optional[Tensor] = None, + window_size: int = 8, + reduction: Optional[Literal["mean", "none", None]] = "mean", +) -> Tensor: + """Compute Spatial Correlation Coefficient (SCC_). + + Args: + preds: predicted images of shape ``(N,C,H,W)`` or ``(N,H,W)``. + target: ground truth images of shape ``(N,C,H,W)`` or ``(N,H,W)``. + hp_filter: High-pass filter tensor. default: tensor([[-1,-1,-1],[-1,8,-1],[-1,-1,-1]]) + window_size: Local window size integer. default: 8, + reduction: Reduction method for output tensor. If ``None`` or ``"none"``, + returns a tensor with the per sample results. default: ``"mean"``. + + Return: + Tensor with scc score + + Example: + >>> import torch + >>> from torchmetrics.functional.image import spatial_correlation_coefficient as scc + >>> _ = torch.manual_seed(42) + >>> x = torch.randn(5, 3, 16, 16) + >>> scc(x, x) + tensor(1.) + >>> x = torch.randn(5, 16, 16) + >>> scc(x, x) + tensor(1.) + >>> x = torch.randn(5, 3, 16, 16) + >>> y = torch.randn(5, 3, 16, 16) + >>> scc(x, y, reduction="none") + tensor([0.0223, 0.0256, 0.0616, 0.0159, 0.0170]) + + """ + if hp_filter is None: + hp_filter = tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]) + if reduction is None: + reduction = "none" + if reduction not in ("mean", "none"): + raise ValueError(f"Expected reduction to be 'mean' or 'none', but got {reduction}") + preds, target, hp_filter = _scc_update(preds, target, hp_filter, window_size) + + per_channel = [ + _scc_per_channel_compute( + preds[:, i, :, :].unsqueeze(1), target[:, i, :, :].unsqueeze(1), hp_filter, window_size + ) + for i in range(preds.size(1)) + ] + if reduction == "none": + return torch.mean(torch.cat(per_channel, dim=1), dim=[1, 2, 3]) + if reduction == "mean": + return reduce(torch.cat(per_channel, dim=1), reduction="elementwise_mean") + return None diff --git a/src/torchmetrics/image/__init__.py b/src/torchmetrics/image/__init__.py index dbd8df587da..4977a2f1b46 100644 --- a/src/torchmetrics/image/__init__.py +++ b/src/torchmetrics/image/__init__.py @@ -21,6 +21,7 @@ from torchmetrics.image.rase import RelativeAverageSpectralError from torchmetrics.image.rmse_sw import RootMeanSquaredErrorUsingSlidingWindow from torchmetrics.image.sam import SpectralAngleMapper +from torchmetrics.image.scc import SpatialCorrelationCoefficient from torchmetrics.image.ssim import MultiScaleStructuralSimilarityIndexMeasure, StructuralSimilarityIndexMeasure from torchmetrics.image.tv import TotalVariation from torchmetrics.image.uqi import UniversalImageQualityIndex @@ -46,6 +47,7 @@ "VisualInformationFidelity", "TotalVariation", "CriticalSuccessIndex", + "SpatialCorrelationCoefficient", ] if _TORCH_FIDELITY_AVAILABLE: diff --git a/src/torchmetrics/image/scc.py b/src/torchmetrics/image/scc.py new file mode 100644 index 00000000000..15ea2b96ecf --- /dev/null +++ b/src/torchmetrics/image/scc.py @@ -0,0 +1,84 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional + +import torch +from torch import Tensor, tensor + +from torchmetrics.functional.image.scc import _scc_per_channel_compute as _scc_compute +from torchmetrics.functional.image.scc import _scc_update +from torchmetrics.metric import Metric + + +class SpatialCorrelationCoefficient(Metric): + """Compute Spatial Correlation Coefficient (SCC_). + + As input to ``forward`` and ``update`` the metric accepts the following input + + - ``preds`` (:class:`~torch.Tensor`): Predictions from model of shape ``(N,C,H,W)`` or ``(N,H,W)``. + - ``target`` (:class:`~torch.Tensor`): Ground truth values of shape ``(N,C,H,W)`` or ``(N,H,W)``. + + As output of `forward` and `compute` the metric returns the following output + + - ``scc`` (:class:`~torch.Tensor`): Tensor with scc score + + Args: + hp_filter: High-pass filter tensor. default: tensor([[-1,-1,-1],[-1,8,-1],[-1,-1,-1]]). + window_size: Local window size integer. default: 8. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example: + >>> import torch + >>> _ = torch.manual_seed(42) + >>> from torchmetrics.image import SpatialCorrelationCoefficient as SCC + >>> preds = torch.randn([32, 3, 64, 64]) + >>> target = torch.randn([32, 3, 64, 64]) + >>> scc = SCC() + >>> scc(preds, target) + tensor(0.0023) + + """ + + is_differentiable = True + higher_is_better = True + full_state_update = False + + scc_score: Tensor + total: Tensor + + def __init__(self, high_pass_filter: Optional[Tensor] = None, window_size: int = 8, **kwargs: Any) -> None: + super().__init__(**kwargs) + + if high_pass_filter is None: + high_pass_filter = tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]) + + self.hp_filter = high_pass_filter + self.ws = window_size + + self.add_state("scc_score", default=tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=tensor(0.0), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.""" + preds, target, hp_filter = _scc_update(preds, target, self.hp_filter, self.ws) + scc_per_channel = [ + _scc_compute(preds[:, i, :, :].unsqueeze(1), target[:, i, :, :].unsqueeze(1), hp_filter, self.ws) + for i in range(preds.size(1)) + ] + self.scc_score += torch.sum(torch.mean(torch.cat(scc_per_channel, dim=1), dim=[1, 2, 3])) + self.total += preds.size(0) + + def compute(self) -> Tensor: + """Compute the VIF score based on inputs passed in to ``update`` previously.""" + return self.scc_score / self.total diff --git a/tests/unittests/image/test_scc.py b/tests/unittests/image/test_scc.py new file mode 100644 index 00000000000..6ef1443371c --- /dev/null +++ b/tests/unittests/image/test_scc.py @@ -0,0 +1,98 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest +import torch +from sewar.full_ref import scc as sewar_scc +from torchmetrics.functional.image import spatial_correlation_coefficient +from torchmetrics.image import SpatialCorrelationCoefficient + +from unittests import BATCH_SIZE, NUM_BATCHES, _Input +from unittests.helpers import seed_all +from unittests.helpers.testers import MetricTester + +seed_all(42) + +_inputs = [ + _Input( + preds=torch.randn(NUM_BATCHES, BATCH_SIZE, channels, 32, 32), + target=torch.randn(NUM_BATCHES, BATCH_SIZE, channels, 32, 32), + ) + for channels in [1, 3] +] +_kernels = [torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]])] + + +def _reference_scc(preds, target): + """Reference implementation of scc from sewar.""" + preds = torch.movedim(preds, 1, -1) + target = torch.movedim(target, 1, -1) + preds = preds.cpu().numpy() + target = target.cpu().numpy() + hp_filter = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]) + window_size = 8 + scc = [ + sewar_scc(GT=target[batch], P=preds[batch], win=hp_filter, ws=window_size) for batch in range(preds.shape[0]) + ] + return np.mean(scc) + + +def _wrapped_reference_scc(win, ws, reduction): + """Wrapper around reference implementation of scc from sewar.""" + + def _wrapped(preds, target): + preds = torch.movedim(preds, 1, -1) + target = torch.movedim(target, 1, -1) + preds = preds.cpu().numpy() + target = target.cpu().numpy() + scc = [sewar_scc(GT=target[batch], P=preds[batch], win=win, ws=ws) for batch in range(preds.shape[0])] + if reduction == "mean": + return np.mean(scc) + if reduction == "none": + return scc + return None + + return _wrapped + + +@pytest.mark.parametrize("preds, target", [(i.preds, i.target) for i in _inputs]) +class TestSpatialCorrelationCoefficient(MetricTester): + """Tests for SpatialCorrelationCoefficient metric.""" + + atol = 1e-8 + + @pytest.mark.parametrize("ddp", [True, False]) + def test_scc(self, preds, target, ddp): + """Test SpatialCorrelationCoefficient class usage.""" + self.run_class_metric_test( + ddp, preds, target, metric_class=SpatialCorrelationCoefficient, reference_metric=_reference_scc + ) + + @pytest.mark.parametrize("hp_filter", _kernels) + @pytest.mark.parametrize("window_size", [8, 11]) + @pytest.mark.parametrize("reduction", ["mean", "none"]) + def test_scc_functional(self, preds, target, hp_filter, window_size, reduction): + """Test SpatialCorrelationCoefficient functional usage.""" + self.run_functional_metric_test( + preds, + target, + metric_functional=spatial_correlation_coefficient, + reference_metric=_wrapped_reference_scc(hp_filter, window_size, reduction), + metric_args={ + "hp_filter": hp_filter, + "window_size": window_size, + "reduction": reduction, + }, + )