Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ignore the NoUtterancesError when calculating pesq for a batch #2753

Merged
merged 15 commits into from
Oct 9, 2024
Merged
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed for Pearson changes inputs ([#2765](https://github.com/Lightning-AI/torchmetrics/pull/2765))


- Fixed bug in `PESQ` metric where `NoUtterancesError` prevented calculating on a batch of data ([#2753](https://github.com/Lightning-AI/torchmetrics/pull/2753))


## [1.4.2] - 2022-09-12

### Added
Expand Down
11 changes: 9 additions & 2 deletions src/torchmetrics/functional/audio/pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

import numpy as np
import torch
from torch import Tensor
Expand Down Expand Up @@ -83,6 +85,11 @@ def perceptual_evaluation_speech_quality(
)
import pesq as pesq_backend

def _issubtype_number(x: Any) -> bool:
return np.issubdtype(type(x), np.number)

_filter_error_msg = np.vectorize(_issubtype_number)

if fs not in (8000, 16000):
raise ValueError(f"Expected argument `fs` to either be 8000 or 16000 but got {fs}")
if mode not in ("wb", "nb"):
Expand All @@ -103,8 +110,8 @@ def perceptual_evaluation_speech_quality(
pesq_val_np = np.empty(shape=(preds_np.shape[0]))
for b in range(preds_np.shape[0]):
pesq_val_np[b] = pesq_backend.pesq(fs, target_np[b, :], preds_np[b, :], mode)
pesq_val = torch.from_numpy(pesq_val_np)
pesq_val = pesq_val.reshape(preds.shape[:-1])
pesq_val = torch.from_numpy(pesq_val_np[_filter_error_msg(pesq_val_np)].astype(np.float32))
pesq_val = pesq_val.reshape(len(pesq_val))

if keep_same_device:
return pesq_val.to(preds.device)
Expand Down
Loading