Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rename PESQ #751

Merged
merged 6 commits into from
Jan 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* `MatthewsCorrcoef` -> `MatthewsCorrCoef`
* `PearsonCorrcoef` -> `PearsonCorrCoef`
* `SpearmanCorrcoef` -> `SpearmanCorrCoef`
- Renamed audio PESQ metrics: ([#751](https://github.com/PyTorchLightning/metrics/pull/751))
* `functional.audio.pesq` -> `functional.audio.perceptual_evaluation_speech_quality`
* `audio.PESQ` -> `audio.PerceptualEvaluationSpeechQuality`
- Renamed audio SDR metrics: ([#711](https://github.com/PyTorchLightning/metrics/pull/711))
* `functional.sdr` -> `functional.signal_distortion_ratio`
* `functional.si_sdr` -> `functional.scale_invariant_signal_distortion_ratio`
Expand Down
6 changes: 3 additions & 3 deletions docs/source/references/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ Functional metrics
Audio
*****

pesq [func]
~~~~~~~~~~~
perceptual_evaluation_speech_quality [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.audio.pesq.pesq
.. autofunction:: torchmetrics.functional.audio.pesq.perceptual_evaluation_speech_quality


permutation_invariant_training [func]
Expand Down
6 changes: 3 additions & 3 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ the metric will be computed over the ``time`` dimension.
>>> snr(preds, target)
tensor(16.1805)

PESQ
~~~~
PerceptualEvaluationSpeechQuality
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.audio.pesq.PESQ
.. autoclass:: torchmetrics.audio.pesq.PerceptualEvaluationSpeechQuality

PermutationInvariantTraining
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
26 changes: 16 additions & 10 deletions tests/audio/test_pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

from tests.helpers import seed_all
from tests.helpers.testers import MetricTester
from torchmetrics.audio.pesq import PESQ
from torchmetrics.functional.audio.pesq import pesq
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6

seed_all(42)
Expand Down Expand Up @@ -82,7 +82,7 @@ def test_pesq(self, preds, target, sk_metric, fs, mode, ddp, dist_sync_on_step):
ddp,
preds,
target,
PESQ,
PerceptualEvaluationSpeechQuality,
sk_metric=partial(average_metric, metric_func=sk_metric),
dist_sync_on_step=dist_sync_on_step,
metric_args=dict(fs=fs, mode=mode),
Expand All @@ -92,14 +92,18 @@ def test_pesq_functional(self, preds, target, sk_metric, fs, mode):
self.run_functional_metric_test(
preds,
target,
pesq,
perceptual_evaluation_speech_quality,
sk_metric,
metric_args=dict(fs=fs, mode=mode),
)

def test_pesq_differentiability(self, preds, target, sk_metric, fs, mode):
self.run_differentiability_test(
preds=preds, target=target, metric_module=PESQ, metric_functional=pesq, metric_args=dict(fs=fs, mode=mode)
preds=preds,
target=target,
metric_module=PerceptualEvaluationSpeechQuality,
metric_functional=perceptual_evaluation_speech_quality,
metric_args=dict(fs=fs, mode=mode),
)

@pytest.mark.skipif(
Expand All @@ -113,13 +117,13 @@ def test_pesq_half_gpu(self, preds, target, sk_metric, fs, mode):
self.run_precision_test_gpu(
preds=preds,
target=target,
metric_module=PESQ,
metric_functional=partial(pesq, fs=fs, mode=mode),
metric_module=PerceptualEvaluationSpeechQuality,
metric_functional=partial(perceptual_evaluation_speech_quality, fs=fs, mode=mode),
metric_args=dict(fs=fs, mode=mode),
)


def test_error_on_different_shape(metric_class=PESQ):
def test_error_on_different_shape(metric_class=PerceptualEvaluationSpeechQuality):
metric = metric_class(16000, "nb")
with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"):
metric(torch.randn(100), torch.randn(50))
Expand All @@ -134,5 +138,7 @@ def test_on_real_audio():

rate, ref = wavfile.read(os.path.join(current_file_dir, "examples/audio_speech.wav"))
rate, deg = wavfile.read(os.path.join(current_file_dir, "examples/audio_speech_bab_0dB.wav"))
assert pesq(torch.from_numpy(deg), torch.from_numpy(ref), rate, "wb") == 1.0832337141036987
assert pesq(torch.from_numpy(deg), torch.from_numpy(ref), rate, "nb") == 1.6072081327438354
pesq = perceptual_evaluation_speech_quality(torch.from_numpy(deg), torch.from_numpy(ref), rate, "wb")
assert pesq == 1.0832337141036987
pesq = perceptual_evaluation_speech_quality(torch.from_numpy(deg), torch.from_numpy(ref), rate, "nb")
assert pesq == 1.6072081327438354
4 changes: 4 additions & 0 deletions torchmetrics/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,7 @@
from torchmetrics.audio.si_sdr import SI_SDR # noqa: F401
from torchmetrics.audio.si_snr import SI_SNR # noqa: F401
from torchmetrics.audio.snr import SNR, ScaleInvariantSignalNoiseRatio, SignalNoiseRatio # noqa: F401
from torchmetrics.utilities.imports import _PESQ_AVAILABLE

if _PESQ_AVAILABLE:
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality # noqa: F401
52 changes: 44 additions & 8 deletions torchmetrics/audio/pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,17 @@
# limitations under the License.
from typing import Any, Callable, Optional

from deprecate import deprecated, void
from torch import Tensor, tensor

from torchmetrics.functional.audio.pesq import pesq
from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality
from torchmetrics.metric import Metric
from torchmetrics.utilities import _future_warning
from torchmetrics.utilities.imports import _PESQ_AVAILABLE


class PESQ(Metric):
"""PESQ (Perceptual Evaluation of Speech Quality)
class PerceptualEvaluationSpeechQuality(Metric):
"""Perceptual Evaluation of Speech Quality (PESQ)

This is a wrapper for the pesq package [1]. . Note that input will be moved to `cpu`
to perform the metric calculation.
Expand Down Expand Up @@ -62,15 +64,15 @@ class PESQ(Metric):
If ``mode`` is not either ``"wb"`` or ``"nb"``

Example:
>>> from torchmetrics.audio.pesq import PESQ
>>> from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
>>> import torch
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> nb_pesq = PESQ(8000, 'nb')
>>> nb_pesq = PerceptualEvaluationSpeechQuality(8000, 'nb')
>>> nb_pesq(preds, target)
tensor(2.2076)
>>> wb_pesq = PESQ(16000, 'wb')
>>> wb_pesq = PerceptualEvaluationSpeechQuality(16000, 'wb')
>>> wb_pesq(preds, target)
tensor(1.7359)

Expand Down Expand Up @@ -100,7 +102,7 @@ def __init__(
)
if not _PESQ_AVAILABLE:
raise ModuleNotFoundError(
"PESQ metric requires that pesq is installed."
"PerceptualEvaluationSpeechQuality metric requires that `pesq` is installed."
" Either install as `pip install torchmetrics[audio]` or `pip install pesq`."
)
if fs not in (8000, 16000):
Expand All @@ -120,11 +122,45 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
preds: Predictions from model
target: Ground truth values
"""
pesq_batch = pesq(preds, target, self.fs, self.mode, False).to(self.sum_pesq.device)
pesq_batch = perceptual_evaluation_speech_quality(preds, target, self.fs, self.mode, False).to(
self.sum_pesq.device
)

self.sum_pesq += pesq_batch.sum()
self.total += pesq_batch.numel()

def compute(self) -> Tensor:
"""Computes average PESQ."""
return self.sum_pesq / self.total


class PESQ(PerceptualEvaluationSpeechQuality):
"""Perceptual Evaluation of Speech Quality (PESQ).

.. deprecated:: v0.7
Use :class:`torchmetrics.audio.PerceptualEvaluationSpeechQuality`. Will be removed in v0.8.

Example:
>>> import torch
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> nb_pesq = PESQ(8000, 'nb')
>>> nb_pesq(preds, target)
tensor(2.2076)
>>> wb_pesq = PESQ(16000, 'wb')
>>> wb_pesq(preds, target)
tensor(1.7359)
"""

@deprecated(target=PerceptualEvaluationSpeechQuality, deprecated_in="0.7", remove_in="0.8", stream=_future_warning)
def __init__(
self,
fs: int,
mode: str,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
) -> None:
void(fs, mode, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)
4 changes: 4 additions & 0 deletions torchmetrics/functional/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,7 @@
from torchmetrics.functional.audio.si_sdr import si_sdr # noqa: F401
from torchmetrics.functional.audio.si_snr import si_snr # noqa: F401
from torchmetrics.functional.audio.snr import scale_invariant_signal_noise_ratio, signal_noise_ratio, snr # noqa: F401
from torchmetrics.utilities.imports import _PESQ_AVAILABLE

if _PESQ_AVAILABLE:
from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality # noqa: F401
32 changes: 28 additions & 4 deletions torchmetrics/functional/audio/pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from deprecate import deprecated, void

from torchmetrics.utilities.imports import _PESQ_AVAILABLE

Expand All @@ -22,10 +23,13 @@
import torch
from torch import Tensor

from torchmetrics.utilities import _future_warning
from torchmetrics.utilities.checks import _check_same_shape


def pesq(preds: Tensor, target: Tensor, fs: int, mode: str, keep_same_device: bool = False) -> Tensor:
def perceptual_evaluation_speech_quality(
preds: Tensor, target: Tensor, fs: int, mode: str, keep_same_device: bool = False
) -> Tensor:
r"""PESQ (Perceptual Evaluation of Speech Quality)

This is a wrapper for the ``pesq`` package [1]. Note that input will be moved to `cpu`
Expand Down Expand Up @@ -58,14 +62,14 @@ def pesq(preds: Tensor, target: Tensor, fs: int, mode: str, keep_same_device: bo
If ``mode`` is not either ``"wb"`` or ``"nb"``

Example:
>>> from torchmetrics.functional.audio.pesq import pesq
>>> from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality
>>> import torch
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> pesq(preds, target, 8000, 'nb')
>>> perceptual_evaluation_speech_quality(preds, target, 8000, 'nb')
tensor(2.2076)
>>> pesq(preds, target, 16000, 'wb')
>>> perceptual_evaluation_speech_quality(preds, target, 16000, 'wb')
tensor(1.7359)

References:
Expand Down Expand Up @@ -98,3 +102,23 @@ def pesq(preds: Tensor, target: Tensor, fs: int, mode: str, keep_same_device: bo
pesq_val = pesq_val.to(preds.device)

return pesq_val


@deprecated(target=perceptual_evaluation_speech_quality, deprecated_in="0.7", remove_in="0.8", stream=_future_warning)
def pesq(preds: Tensor, target: Tensor, fs: int, mode: str, keep_same_device: bool = False) -> Tensor:
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
r"""PESQ (Perceptual Evaluation of Speech Quality)

.. deprecated:: v0.7
Use :func:`torchmetrics.functional.audio.perceptual_evaluation_speech_quality`. Will be removed in v0.8.

Example:
>>> import torch
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> pesq(preds, target, 8000, 'nb')
tensor(2.2076)
>>> pesq(preds, target, 16000, 'wb')
tensor(1.7359)
"""
return void(preds, target, fs, mode, keep_same_device)