Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding psnrb #1421

Merged
merged 50 commits into from
Apr 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
f487bad
Adding psnrb
soma2000-lang Jan 1, 2023
c4e3fe3
Adding the changes suggested
soma2000-lang Jan 4, 2023
b0e963c
Adding the changes suggested
soma2000-lang Jan 4, 2023
f85d3c8
changes
soma2000-lang Jan 4, 2023
63df2b7
changes
soma2000-lang Jan 4, 2023
ed1a6cc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 4, 2023
0b2b9dc
imports
Borda Jan 6, 2023
fb450d4
cls name
Borda Jan 6, 2023
64edfaf
Merge branch 'master' into psnrb
Borda Jan 6, 2023
87142b8
precommit
Borda Jan 6, 2023
5a91e17
Merge branch 'psnrb' of https://github.com/soma2000-lang/metrics into…
Borda Jan 6, 2023
9b4207c
fix changelog
SkafteNicki Jan 24, 2023
3c805b2
remove unwanted files
SkafteNicki Jan 24, 2023
32bd8b3
rename file
SkafteNicki Jan 24, 2023
59c7144
fix docstring
SkafteNicki Jan 24, 2023
1f10784
Merge branch 'master' into psnrb
SkafteNicki Jan 24, 2023
af0ea3f
Merge branch 'master' into psnrb
Borda Jan 30, 2023
303d9ba
small fixes
SkafteNicki Feb 3, 2023
8678087
Merge branch 'master' into psnrb
Borda Feb 6, 2023
0a7944b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2023
c812d8b
doctest
Borda Feb 6, 2023
530b6a4
Apply suggestions from code review
Borda Feb 6, 2023
bed471a
Merge branch 'master' into psnrb
Borda Feb 7, 2023
1378054
Merge branch 'master' into psnrb
Borda Feb 17, 2023
c45b70f
Merge branch 'master' into psnrb
Borda Feb 18, 2023
3b550cf
Merge branch 'master' into psnrb
Borda Feb 22, 2023
97fe8bf
Merge branch 'psnrb' of https://github.com/soma2000-lang/metrics into…
SkafteNicki Feb 24, 2023
d085733
Merge branch 'master' into psnrb
SkafteNicki Feb 24, 2023
317756f
fixes
SkafteNicki Feb 24, 2023
2e487a4
Merge branch 'master' into psnrb
Borda Feb 27, 2023
a990ce6
Merge branch 'master' into psnrb
soma2000-lang Feb 27, 2023
bd84dc0
Merge branch 'master' into psnrb
Borda Feb 28, 2023
5f7d496
Merge branch 'master' into psnrb
Borda Mar 6, 2023
e158f41
Merge branch 'master' into psnrb
Borda Mar 6, 2023
d6de718
Merge branch 'master' into psnrb
Borda Mar 21, 2023
cc7eab5
Merge branch 'master' into psnrb
Borda Mar 31, 2023
53ce7ec
Merge branch 'psnrb' of https://github.com/soma2000-lang/metrics into…
SkafteNicki Apr 14, 2023
5012c5c
merge master
SkafteNicki Apr 14, 2023
4133e75
fix changelog
SkafteNicki Apr 14, 2023
cad1cb3
fix
SkafteNicki Apr 14, 2023
e89623b
fix implementation and tests
SkafteNicki Apr 14, 2023
ebb2bfe
Merge branch 'master' into psnrb
SkafteNicki Apr 15, 2023
4ecde1d
docs: Import through torchmetrics.image
stancld Apr 15, 2023
fac071f
Apply suggestions from code review
stancld Apr 15, 2023
8393c5e
Update regex match accordingly in tests/unittests/image/test_psnrb.py
stancld Apr 16, 2023
09818dd
Merge branch 'master' into psnrb
mergify[bot] Apr 17, 2023
489b06e
Merge branch 'master' into psnrb
Borda Apr 17, 2023
8075256
Merge branch 'master' into psnrb
mergify[bot] Apr 17, 2023
36dfff0
Merge branch 'master' into psnrb
mergify[bot] Apr 17, 2023
8181e61
Merge branch 'master' into psnrb
mergify[bot] Apr 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added Binary group fairness metrics to classification package ([#1404](https://github.com/Lightning-AI/metrics/pull/1404))


- Added new detection metric `PanopticQuality` ([#929](https://github.com/PyTorchLightning/metrics/pull/929))


- Added `MinkowskiDistance` to regression package ([#1362](https://github.com/Lightning-AI/metrics/pull/1362))


Expand All @@ -65,6 +62,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
)


- Added `PSNRB` metric ([#1421](https://github.com/Lightning-AI/metrics/pull/1421))


- Added `ClassificationTask` Enum and use in metrics ([#1479](https://github.com/Lightning-AI/metrics/pull/1479))


Expand Down
23 changes: 23 additions & 0 deletions docs/source/image/peak_signal_to_noise_with_block.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
.. customcarditem::
:header: Peak Signal To Noise Ratio With Blocked Effect
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
:tags: Image

.. include:: ../links.rst

##############################################
Peak Signal To Noise Ratio With Blocked Effect
##############################################

Module Interface
________________

.. autoclass:: torchmetrics.image.PeakSignalNoiseRatioWithBlockedEffect
:noindex:
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.image.peak_signal_noise_ratio_with_blocked_effect
:noindex:
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@
.. _Multilabel coverage error: https://link.springer.com/chapter/10.1007/978-0-387-09823-4_34
.. _Panoptic Quality: https://arxiv.org/abs/1801.00868
.. _torchmetrics mAP example: https://github.com/Lightning-AI/metrics/blob/master/examples/detection_map.py
.. _Peak Signal to Noise Ratio With Blocked Effect: https://ieeexplore.ieee.org/abstract/document/5535179
.. _Minkowski Distance: https://en.wikipedia.org/wiki/Minkowski_distance
.. _Demographic parity: http://www.fairmlbook.org/
.. _Equal opportunity: https://proceedings.neurips.cc/paper/2016/hash/9d2682367c3935defcb1f9e247a97c0d-Abstract.html
Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis
from torchmetrics.functional.image.gradients import image_gradients
from torchmetrics.functional.image.psnr import peak_signal_noise_ratio
from torchmetrics.functional.image.psnrb import peak_signal_noise_ratio_with_blocked_effect
from torchmetrics.functional.image.rase import relative_average_spectral_error
from torchmetrics.functional.image.rmse_sw import root_mean_squared_error_using_sliding_window
from torchmetrics.functional.image.sam import spectral_angle_mapper
Expand All @@ -30,6 +31,7 @@
"error_relative_global_dimensionless_synthesis",
"image_gradients",
"peak_signal_noise_ratio",
"peak_signal_noise_ratio_with_blocked_effect",
"relative_average_spectral_error",
"root_mean_squared_error_using_sliding_window",
"spectral_angle_mapper",
Expand Down
133 changes: 133 additions & 0 deletions src/torchmetrics/functional/image/psnrb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple, Union

import torch
from torch import Tensor, tensor


def _compute_bef(x: Tensor, block_size: int = 8) -> Tensor:
"""Compute block effect.

Args:
x: input image
block_size: integer indication the block size

Returns:
Computed block effect

Raises:
ValueError:
If the image is not a grayscale image

"""
(
_,
channels,
height,
width,
) = x.shape
if channels > 1:
raise ValueError(f"`psnrb` metric expects grayscale images, but got images with {channels} channels.")

h = torch.arange(width - 1)
h_b = torch.tensor(range(block_size - 1, width - 1, block_size))
h_bc = torch.tensor(list(set(h.tolist()).symmetric_difference(h_b.tolist())))

v = torch.arange(height - 1)
v_b = torch.tensor(range(block_size - 1, height - 1, block_size))
v_bc = torch.tensor(list(set(v.tolist()).symmetric_difference(v_b.tolist())))

d_b = (x[:, :, :, h_b] - x[:, :, :, h_b + 1]).pow(2.0).sum()
d_bc = (x[:, :, :, h_bc] - x[:, :, :, h_bc + 1]).pow(2.0).sum()
d_b += (x[:, :, v_b, :] - x[:, :, v_b + 1, :]).pow(2.0).sum()
d_bc += (x[:, :, v_bc, :] - x[:, :, v_bc + 1, :]).pow(2.0).sum()

n_hb = height * (width / block_size) - 1
n_hbc = (height * (width - 1)) - n_hb
n_vb = width * (height / block_size) - 1
n_vbc = (width * (height - 1)) - n_vb
d_b /= n_hb + n_vb
d_bc /= n_hbc + n_vbc
t = math.log2(block_size) / math.log2(min(height, width)) if d_b > d_bc else 0
return t * (d_b - d_bc)


def _psnrb_compute(
sum_squared_error: Tensor,
bef: Tensor,
n_obs: Tensor,
data_range: Tensor,
) -> Tensor:
"""Computes peak signal-to-noise ratio.

Args:
sum_squared_error: Sum of square of errors over all observations
bef: block effect
n_obs: Number of predictions or observations
data_range: the range of the data. If None, it is determined from the data (max - min).
"""
sum_squared_error = sum_squared_error / n_obs + bef
if data_range > 2:
return 10 * torch.log10(data_range**2 / sum_squared_error)
return 10 * torch.log10(1.0 / sum_squared_error)


def _psnrb_update(preds: Tensor, target: Tensor, block_size: int = 8) -> Tuple[Tensor, Tensor, Tensor]:
"""Updates and returns variables required to compute peak signal-to-noise ratio.

Args:
preds: Predicted tensor
target: Ground truth tensor
block_size: Integer indication the block size
"""
sum_squared_error = torch.sum(torch.pow(preds - target, 2))
n_obs = tensor(target.numel(), device=target.device)
bef = _compute_bef(preds, block_size=block_size)
return sum_squared_error, bef, n_obs


def peak_signal_noise_ratio_with_blocked_effect(
preds: Tensor,
target: Tensor,
block_size: int = 8,
) -> Tensor:
r"""Computes `Peak Signal to Noise Ratio With Blocked Effect` (PSNRB) metrics.

.. math::
\text{PSNRB}(I, J) = 10 * \log_{10} \left(\frac{\max(I)^2}{\text{MSE}(I, J)-\text{B}(I, J)}\right)

Where :math:`\text{MSE}` denotes the `mean-squared-error`_ function.

Args:
preds: estimated signal
target: groun truth signal
block_size: integer indication the block size

Return:
Tensor with PSNRB score

Example:
>>> import torch
>>> from torchmetrics.functional.image import peak_signal_noise_ratio_with_blocked_effect
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand(1, 1, 28, 28)
>>> target = torch.rand(1, 1, 28, 28)
>>> peak_signal_noise_ratio_with_blocked_effect(preds, target)
tensor(7.8402)
"""
data_range = target.max() - target.min()
sum_squared_error, bef, n_obs = _psnrb_update(preds, target, block_size=block_size)
return _psnrb_compute(sum_squared_error, bef, n_obs, data_range)
2 changes: 2 additions & 0 deletions src/torchmetrics/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from torchmetrics.image.d_lambda import SpectralDistortionIndex
from torchmetrics.image.ergas import ErrorRelativeGlobalDimensionlessSynthesis
from torchmetrics.image.psnr import PeakSignalNoiseRatio
from torchmetrics.image.psnrb import PeakSignalNoiseRatioWithBlockedEffect
from torchmetrics.image.rase import RelativeAverageSpectralError
from torchmetrics.image.rmse_sw import RootMeanSquaredErrorUsingSlidingWindow
from torchmetrics.image.sam import SpectralAngleMapper
Expand All @@ -26,6 +27,7 @@
"SpectralDistortionIndex",
"ErrorRelativeGlobalDimensionlessSynthesis",
"PeakSignalNoiseRatio",
"PeakSignalNoiseRatioWithBlockedEffect",
"RelativeAverageSpectralError",
"RootMeanSquaredErrorUsingSlidingWindow",
"SpectralAngleMapper",
Expand Down
139 changes: 139 additions & 0 deletions src/torchmetrics/image/psnrb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Sequence, Tuple, Union

import torch
from torch import Tensor, tensor
from typing_extensions import Literal

from torchmetrics.functional.image.psnrb import _psnrb_compute, _psnrb_update
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["PeakSignalNoiseRatioWithBlockedEffect.plot"]


class PeakSignalNoiseRatioWithBlockedEffect(Metric):
r"""Computes `Peak Signal to Noise Ratio With Blocked Effect`_ (PSNRB).

.. math::
\text{PSNRB}(I, J) = 10 * \log_{10} \left(\frac{\max(I)^2}{\text{MSE}(I, J)-\text{B}(I, J)}\right)

Where :math:`\text{MSE}` denotes the `mean-squared-error`_ function. This metric is a modified version of PSNR that
better supports evaluation of images with blocked artifacts, that oftens occur in compressed images.

.. note::
Metric only supports grayscale images. If you have RGB images, please convert them to grayscale first.

As input to ``forward`` and ``update`` the metric accepts the following input

- ``preds`` (:class:`~torch.Tensor`): Predictions from model of shape ``(N,1,H,W)``
- ``target`` (:class:`~torch.Tensor`): Ground truth values of shape ``(N,1,H,W)``

As output of `forward` and `compute` the metric returns the following output

- ``psnrb`` (:class:`~torch.Tensor`): float scalar tensor with aggregated PSNRB value

Args:
block_size: integer indication the block size
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example:
>>> import torch
>>> from torchmetrics.image import PeakSignalNoiseRatioWithBlockedEffect
>>> metric = PeakSignalNoiseRatioWithBlockedEffect()
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand(2, 1, 10, 10)
>>> target = torch.rand(2, 1, 10, 10)
>>> metric(preds, target)
tensor(7.2893)
"""
is_differentiable: bool = True
higher_is_better: bool = True
full_state_update: bool = False

sum_squared_error: Tensor
total: Tensor
bef: Tensor
data_range: Tensor

def __init__(
self,
block_size: int = 8,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
if not isinstance(block_size, int) and block_size < 1:
raise ValueError("Argument ``block_size`` should be a positive integer")
self.block_size = block_size

self.add_state("sum_squared_error", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
self.add_state("bef", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("data_range", default=tensor(0), dist_reduce_fx="max")

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets."""
sum_squared_error, bef, n_obs = _psnrb_update(preds, target, block_size=self.block_size)
self.sum_squared_error += sum_squared_error
self.bef += bef
self.total += n_obs
self.data_range = torch.maximum(self.data_range, torch.max(target) - torch.min(target))

def compute(self) -> Tensor:
"""Compute peak signal-to-noise ratio over state."""
return _psnrb_compute(self.sum_squared_error, self.bef, self.total, self.data_range)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image import PeakSignalNoiseRatioWithBlockedEffect
>>> metric = PeakSignalNoiseRatioWithBlockedEffect()
>>> metric.update(torch.rand(2, 1, 10, 10), torch.rand(2, 1, 10, 10))
>>> fig_, ax_ = metric.plot()

.. plot::
:scale: 75

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image import PeakSignalNoiseRatioWithBlockedEffect
>>> metric = PeakSignalNoiseRatioWithBlockedEffect()
>>> values = [ ]
>>> for _ in range(10):
... values.append(metric(torch.rand(2, 1, 10, 10), torch.rand(2, 1, 10, 10)))
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)
Loading