Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Spatial Correlation Coefficient (SCC) metric #2248

Merged
merged 31 commits into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
c3649ca
SpatialCorrelationCoefficient functionality and module added.
HoseinAkbarzadeh Nov 29, 2023
be3da87
tests for spatial correlation coefficient (scc) added.
HoseinAkbarzadeh Nov 29, 2023
d678f2a
documentation for spatial correlation coefficient (scc) updated.
HoseinAkbarzadeh Nov 29, 2023
952233b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 29, 2023
48164a4
Apply suggestions from code review
Borda Nov 30, 2023
2637b23
scc functional docstrings added. _hp_2d_laplacian function updated
HoseinAkbarzadeh Nov 30, 2023
4a0058a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 30, 2023
0668236
required failed checks resolved.
HoseinAkbarzadeh Nov 30, 2023
f231192
fixing merge conflict
HoseinAkbarzadeh Nov 30, 2023
24bfc77
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 30, 2023
33b6d90
fixing the variable name mistake.
HoseinAkbarzadeh Nov 30, 2023
f78db34
merge conflict resovled.
HoseinAkbarzadeh Nov 30, 2023
52bdb76
Merge branch 'master' into master
SkafteNicki Nov 30, 2023
f5a874f
fixed even window size bug. changed atol to 1e-8. added None reductio…
HoseinAkbarzadeh Dec 4, 2023
a9c25ee
added new tests for scc functional interface'
HoseinAkbarzadeh Dec 4, 2023
2d9300b
resolving failed mypy checks.
HoseinAkbarzadeh Dec 4, 2023
58b3935
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 4, 2023
5a501ed
resolved long docstring line
HoseinAkbarzadeh Dec 4, 2023
7df880f
merge fix
HoseinAkbarzadeh Dec 4, 2023
38c9f0f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 4, 2023
650d227
fixed example bug in docstring
HoseinAkbarzadeh Dec 5, 2023
7605971
fixing merge conflict
HoseinAkbarzadeh Dec 5, 2023
3e577ee
Merge branch 'master' into master
HoseinAkbarzadeh Dec 5, 2023
f423c0c
Merge branch 'master' into master
SkafteNicki Dec 20, 2023
45b6af4
Update src/torchmetrics/functional/image/scc.py
SkafteNicki Dec 20, 2023
6c9308f
changelog
SkafteNicki Dec 20, 2023
9a6f888
Merge branch 'master' into master
mergify[bot] Dec 21, 2023
b5b9702
Merge branch 'master' into HoseinAkbarzadeh/master
Borda Dec 21, 2023
3426755
Apply suggestions from code review
Borda Dec 21, 2023
b4892d5
link
Borda Dec 21, 2023
5e18e82
Merge branch 'master' into master
SkafteNicki Dec 22, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added `CriticalSuccessIndex` metric to image subpackage ([#2257](https://github.com/Lightning-AI/torchmetrics/pull/2257))


- Added `Spatial Correlation Coefficient` to image subpackage ([#2248](https://github.com/Lightning-AI/torchmetrics/pull/2248))


### Changed

- Changed minimum supported Pytorch version from 1.8 to 1.10 ([#2145](https://github.com/Lightning-AI/torchmetrics/pull/2145))
Expand Down
21 changes: 21 additions & 0 deletions docs/source/image/spatial_correlation_coefficient.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
.. customcarditem::
:header: Spatial Correlation Coefficient (SCC)
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
:tags: Image

.. include:: ../links.rst

#####################################
Spatial Correlation Coefficient (SCC)
#####################################

Module Interface
________________

.. autoclass:: torchmetrics.image.SpatialCorrelationCoefficient
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.image.spatial_correlation_coefficient
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,4 @@
.. _FLORES-101: https://arxiv.org/abs/2106.03193
.. _FLORES-200: https://arxiv.org/abs/2207.04672
.. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
.. _SCC: https://www.tandfonline.com/doi/abs/10.1080/014311698215973
Borda marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from torchmetrics.functional.image.rase import relative_average_spectral_error
from torchmetrics.functional.image.rmse_sw import root_mean_squared_error_using_sliding_window
from torchmetrics.functional.image.sam import spectral_angle_mapper
from torchmetrics.functional.image.scc import spatial_correlation_coefficient
from torchmetrics.functional.image.ssim import (
multiscale_structural_similarity_index_measure,
structural_similarity_index_measure,
Expand All @@ -47,4 +48,5 @@
"learned_perceptual_image_patch_similarity",
"perceptual_path_length",
"critical_success_index",
"spatial_correlation_coefficient",
]
221 changes: 221 additions & 0 deletions src/torchmetrics/functional/image/scc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple, Union

import torch
from torch import Tensor, tensor
from torch.nn.functional import conv2d, pad
from typing_extensions import Literal

from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.distributed import reduce


def _scc_update(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: int) -> Tuple[Tensor, Tensor, Tensor]:
"""Update and returns variables required to compute Spatial Correlation Coefficient.

Args:
preds: Predicted tensor
target: Ground truth tensor
hp_filter: High-pass filter tensor
window_size: Local window size integer

Return:
Tuple of (preds, target, hp_filter) tensors

Raises:
ValueError:
If ``preds`` and ``target`` have different number of channels
If ``preds`` and ``target`` have different shapes
If ``preds`` and ``target`` have invalid shapes
If ``window_size`` is not a positive integer
If ``window_size`` is greater than the size of the image

"""
if preds.dtype != target.dtype:
target = target.to(preds.dtype)
_check_same_shape(preds, target)
if preds.ndim not in (3, 4):
raise ValueError(
"Expected `preds` and `target` to have batch of colored images with BxCxHxW shape"
" or batch of grayscale images of BxHxW shape."
f" Got preds: {preds.shape} and target: {target.shape}."
)

if len(preds.shape) == 3:
preds = preds.unsqueeze(1)
target = target.unsqueeze(1)

if not window_size > 0:
raise ValueError(f"Expected `window_size` to be a positive integer. Got {window_size}.")

if window_size > preds.size(2) or window_size > preds.size(3):
raise ValueError(
f"Expected `window_size` to be less than or equal to the size of the image."
f" Got window_size: {window_size} and image size: {preds.size(2)}x{preds.size(3)}."
)

preds = preds.to(torch.float32)
target = target.to(torch.float32)
hp_filter = hp_filter[None, None, :].to(dtype=preds.dtype, device=preds.device)
return preds, target, hp_filter


def _symmetric_reflect_pad_2d(input_img: Tensor, pad: Union[int, Tuple[int, ...]]) -> Tensor:
"""Applies symmetric padding to the 2D image tensor input using ``reflect`` mode (d c b a | a b c d | d c b a)."""
if isinstance(pad, int):
pad = (pad, pad, pad, pad)
if len(pad) != 4:
raise ValueError(f"Expected padding to have length 4, but got {len(pad)}")

left_pad = input_img[:, :, :, 0 : pad[0]].flip(dims=[3])
right_pad = input_img[:, :, :, -pad[1] :].flip(dims=[3])
padded = torch.cat([left_pad, input_img, right_pad], dim=3)

top_pad = padded[:, :, 0 : pad[2], :].flip(dims=[2])
bottom_pad = padded[:, :, -pad[3] :, :].flip(dims=[2])
return torch.cat([top_pad, padded, bottom_pad], dim=2)


def _signal_convolve_2d(input_img: Tensor, kernel: Tensor) -> Tensor:
"""Applies 2D signal convolution to the input tensor with the given kernel."""
left_padding = int(math.floor((kernel.size(3) - 1) / 2))
right_padding = int(math.ceil((kernel.size(3) - 1) / 2))
top_padding = int(math.floor((kernel.size(2) - 1) / 2))
bottom_padding = int(math.ceil((kernel.size(2) - 1) / 2))

padded = _symmetric_reflect_pad_2d(input_img, pad=(left_padding, right_padding, top_padding, bottom_padding))
kernel = kernel.flip([2, 3])
return conv2d(padded, kernel, stride=1, padding=0)


def _hp_2d_laplacian(input_img: Tensor, kernel: Tensor) -> Tensor:
"""Applies 2-D Laplace filter to the input tensor with the given high pass filter."""
return _signal_convolve_2d(input_img, kernel) * 2.0


def _local_variance_covariance(preds: Tensor, target: Tensor, window: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
"""Computes local variance and covariance of the input tensors."""
# This code is inspired by
# https://github.com/andrewekhalel/sewar/blob/master/sewar/full_ref.py#L187.

left_padding = int(math.ceil((window.size(3) - 1) / 2))
right_padding = int(math.floor((window.size(3) - 1) / 2))

preds = pad(preds, (left_padding, right_padding, left_padding, right_padding))
target = pad(target, (left_padding, right_padding, left_padding, right_padding))

preds_mean = conv2d(preds, window, stride=1, padding=0)
target_mean = conv2d(target, window, stride=1, padding=0)

preds_var = conv2d(preds**2, window, stride=1, padding=0) - preds_mean**2
target_var = conv2d(target**2, window, stride=1, padding=0) - target_mean**2
target_preds_cov = conv2d(target * preds, window, stride=1, padding=0) - target_mean * preds_mean

return preds_var, target_var, target_preds_cov


def _scc_per_channel_compute(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: int) -> Tensor:
"""Computes per channel Spatial Correlation Coefficient.

Args:
preds: estimated image of Bx1xHxW shape.
target: ground truth image of Bx1xHxW shape.
hp_filter: 2D high-pass filter.
window_size: size of window for local mean calculation.

Return:
Tensor with Spatial Correlation Coefficient score

"""
dtype = preds.dtype
device = preds.device

# This code is inspired by
# https://github.com/andrewekhalel/sewar/blob/master/sewar/full_ref.py#L187.

window = torch.ones(size=(1, 1, window_size, window_size), dtype=dtype, device=device) / (window_size**2)

preds_hp = _hp_2d_laplacian(preds, hp_filter)
target_hp = _hp_2d_laplacian(target, hp_filter)

preds_var, target_var, target_preds_cov = _local_variance_covariance(preds_hp, target_hp, window)

preds_var[preds_var < 0] = 0
target_var[target_var < 0] = 0

den = torch.sqrt(target_var) * torch.sqrt(preds_var)
idx = den == 0
den[den == 0] = 1
scc = target_preds_cov / den
scc[idx] = 0
return scc


def spatial_correlation_coefficient(
preds: Tensor,
target: Tensor,
hp_filter: Optional[Tensor] = None,
window_size: int = 8,
reduction: Optional[Literal["mean", "none", None]] = "mean",
) -> Tensor:
"""Compute Spatial Correlation Coefficient (SCC_).

Args:
preds: predicted images of shape ``(N,C,H,W)`` or ``(N,H,W)``.
target: ground truth images of shape ``(N,C,H,W)`` or ``(N,H,W)``.
hp_filter: High-pass filter tensor. default: tensor([[-1,-1,-1],[-1,8,-1],[-1,-1,-1]])
window_size: Local window size integer. default: 8,
reduction: Reduction method for output tensor. If ``None`` or ``"none"``,
returns a tensor with the per sample results. default: ``"mean"``.

Return:
Tensor with scc score

Example:
>>> import torch
>>> from torchmetrics.functional.image import spatial_correlation_coefficient as scc
>>> _ = torch.manual_seed(42)
>>> x = torch.randn(5, 3, 16, 16)
>>> scc(x, x)
tensor(1.)
>>> x = torch.randn(5, 16, 16)
>>> scc(x, x)
tensor(1.)
>>> x = torch.randn(5, 3, 16, 16)
>>> y = torch.randn(5, 3, 16, 16)
>>> scc(x, y, reduction="none")
tensor([0.0223, 0.0256, 0.0616, 0.0159, 0.0170])

"""
if hp_filter is None:
hp_filter = tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]])
if reduction is None:
reduction = "none"
if reduction not in ("mean", "none"):
raise ValueError(f"Expected reduction to be 'mean' or 'none', but got {reduction}")
preds, target, hp_filter = _scc_update(preds, target, hp_filter, window_size)

per_channel = [
_scc_per_channel_compute(
preds[:, i, :, :].unsqueeze(1), target[:, i, :, :].unsqueeze(1), hp_filter, window_size
)
for i in range(preds.size(1))
]
if reduction == "none":
return torch.mean(torch.cat(per_channel, dim=1), dim=[1, 2, 3])
if reduction == "mean":
return reduce(torch.cat(per_channel, dim=1), reduction="elementwise_mean")
return None
2 changes: 2 additions & 0 deletions src/torchmetrics/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torchmetrics.image.rase import RelativeAverageSpectralError
from torchmetrics.image.rmse_sw import RootMeanSquaredErrorUsingSlidingWindow
from torchmetrics.image.sam import SpectralAngleMapper
from torchmetrics.image.scc import SpatialCorrelationCoefficient
from torchmetrics.image.ssim import MultiScaleStructuralSimilarityIndexMeasure, StructuralSimilarityIndexMeasure
from torchmetrics.image.tv import TotalVariation
from torchmetrics.image.uqi import UniversalImageQualityIndex
Expand All @@ -44,6 +45,7 @@
"VisualInformationFidelity",
"TotalVariation",
"CriticalSuccessIndex",
"SpatialCorrelationCoefficient",
]

if _TORCH_FIDELITY_AVAILABLE:
Expand Down
84 changes: 84 additions & 0 deletions src/torchmetrics/image/scc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional

import torch
from torch import Tensor, tensor

from torchmetrics.functional.image.scc import _scc_per_channel_compute as _scc_compute
from torchmetrics.functional.image.scc import _scc_update
from torchmetrics.metric import Metric


class SpatialCorrelationCoefficient(Metric):
"""Compute Spatial Correlation Coefficient (SCC_).

As input to ``forward`` and ``update`` the metric accepts the following input

- ``preds`` (:class:`~torch.Tensor`): Predictions from model of shape ``(N,C,H,W)`` or ``(N,H,W)``.
- ``target`` (:class:`~torch.Tensor`): Ground truth values of shape ``(N,C,H,W)`` or ``(N,H,W)``.

As output of `forward` and `compute` the metric returns the following output

- ``scc`` (:class:`~torch.Tensor`): Tensor with scc score

Args:
hp_filter: High-pass filter tensor. default: tensor([[-1,-1,-1],[-1,8,-1],[-1,-1,-1]]).
window_size: Local window size integer. default: 8.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example:
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.image import SpatialCorrelationCoefficient as SCC
>>> preds = torch.randn([32, 3, 64, 64])
>>> target = torch.randn([32, 3, 64, 64])
>>> scc = SCC()
>>> scc(preds, target)
tensor(0.0023)

"""

is_differentiable = True
higher_is_better = True
full_state_update = False

scc_score: Tensor
total: Tensor

def __init__(self, high_pass_filter: Optional[Tensor] = None, window_size: int = 8, **kwargs: Any) -> None:
super().__init__(**kwargs)

if high_pass_filter is None:
high_pass_filter = tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]])

self.hp_filter = high_pass_filter
self.ws = window_size

self.add_state("scc_score", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0.0), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets."""
preds, target, hp_filter = _scc_update(preds, target, self.hp_filter, self.ws)
scc_per_channel = [
_scc_compute(preds[:, i, :, :].unsqueeze(1), target[:, i, :, :].unsqueeze(1), hp_filter, self.ws)
for i in range(preds.size(1))
]
self.scc_score += torch.sum(torch.mean(torch.cat(scc_per_channel, dim=1), dim=[1, 2, 3]))
self.total += preds.size(0)

def compute(self) -> Tensor:
"""Compute the VIF score based on inputs passed in to ``update`` previously."""
return self.scc_score / self.total
Loading
Loading