Skip to content

Commit

Permalink
Ignore the NoUtterancesError when calculating pesq for a batch (#2753)
Browse files Browse the repository at this point in the history
* filter nan

* use np

* fix tests

* fix dtype from object

* as func

* typing

* changelog

---------

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
Co-authored-by: Jirka B <j.borovec+github@gmail.com>

(cherry picked from commit 8cf181f)
  • Loading branch information
veera-puthiran-14082 authored and Borda committed Oct 9, 2024
1 parent 4ebfdb8 commit 56a376c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Fixed for Pearson changes inputs ([#2765](https://github.com/Lightning-AI/torchmetrics/pull/2765))


- Fixed bug in `PESQ` metric where `NoUtterancesError` prevented calculating on a batch of data ([#2753](https://github.com/Lightning-AI/torchmetrics/pull/2753))


## [1.4.2] - 2022-09-12

### Added
Expand Down
11 changes: 9 additions & 2 deletions src/torchmetrics/functional/audio/pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

import numpy as np
import torch
from torch import Tensor
Expand Down Expand Up @@ -83,6 +85,11 @@ def perceptual_evaluation_speech_quality(
)
import pesq as pesq_backend

def _issubtype_number(x: Any) -> bool:
return np.issubdtype(type(x), np.number)

_filter_error_msg = np.vectorize(_issubtype_number)

if fs not in (8000, 16000):
raise ValueError(f"Expected argument `fs` to either be 8000 or 16000 but got {fs}")
if mode not in ("wb", "nb"):
Expand All @@ -103,8 +110,8 @@ def perceptual_evaluation_speech_quality(
pesq_val_np = np.empty(shape=(preds_np.shape[0]))
for b in range(preds_np.shape[0]):
pesq_val_np[b] = pesq_backend.pesq(fs, target_np[b, :], preds_np[b, :], mode)
pesq_val = torch.from_numpy(pesq_val_np)
pesq_val = pesq_val.reshape(preds.shape[:-1])
pesq_val = torch.from_numpy(pesq_val_np[_filter_error_msg(pesq_val_np)].astype(np.float32))
pesq_val = pesq_val.reshape(len(pesq_val))

if keep_same_device:
return pesq_val.to(preds.device)
Expand Down

0 comments on commit 56a376c

Please sign in to comment.