make metric_mtx type and device correct (#533)
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
5 people authored Sep 22, 2021
1 parent d6112a2 commit f9d7d5f
Showing 2 changed files with 15 additions and 4 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added Learned Perceptual Image Patch Similarity (LPIPS) ([#431](https://github.com/PyTorchLightning/metrics/issues/431))


- Added Tweedie Deviance Score ([#499](https://github.com/PyTorchLightning/metrics/pull/499))


@@ -42,6 +43,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed bug in `F1` with `average='macro'` and `ignore_index!=None` ([#495](https://github.com/PyTorchLightning/metrics/pull/495))


- Fixed bug in `pit` by using the returned first result to initialize device and type ([#533](https://github.com/PyTorchLightning/metrics/pull/533))


- Fixed `SSIM` metric using too much memory ([#539](https://github.com/PyTorchLightning/metrics/pull/539))


15 changes: 11 additions & 4 deletions torchmetrics/functional/audio/pit.py
@@ -151,10 +151,17 @@ def pit(

     # calculate the metric matrix
     batch_size, spk_num = target.shape[0:2]
-    metric_mtx = torch.empty((batch_size, spk_num, spk_num), dtype=preds.dtype, device=target.device)
-    for t in range(spk_num):
-        for e in range(spk_num):
-            metric_mtx[:, t, e] = metric_func(preds[:, e, ...], target[:, t, ...], **kwargs)
+    metric_mtx = None
+    for target_idx in range(spk_num):  # we have spk_num speeches in target in each sample
+        for preds_idx in range(spk_num):  # we have spk_num speeches in preds in each sample
+            if metric_mtx is not None:
+                metric_mtx[:, target_idx, preds_idx] = metric_func(
+                    preds[:, preds_idx, ...], target[:, target_idx, ...], **kwargs
+                )
+            else:
+                first_ele = metric_func(preds[:, preds_idx, ...], target[:, target_idx, ...], **kwargs)
+                metric_mtx = torch.empty((batch_size, spk_num, spk_num), dtype=first_ele.dtype, device=first_ele.device)
+                metric_mtx[:, target_idx, preds_idx] = first_ele

     # find best
     op = torch.max if eval_func == "max" else torch.min
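For context: before this change, `metric_mtx` was allocated with `preds.dtype` and `target.device`, so a `metric_func` returning a tensor with a different dtype (or on a different device) could be silently cast or fail on assignment. Below is a minimal standalone sketch of the lazy-allocation pattern the new code uses; `fill_metric_matrix` and `mse_f64` are hypothetical names for illustration only, not part of torchmetrics.

import torch

def fill_metric_matrix(preds, target, metric_func, **kwargs):
    # Allocate the matrix lazily from the first computed result, so its dtype
    # and device always match whatever metric_func actually returns.
    batch_size, spk_num = target.shape[0:2]
    metric_mtx = None
    for t in range(spk_num):
        for e in range(spk_num):
            val = metric_func(preds[:, e, ...], target[:, t, ...], **kwargs)
            if metric_mtx is None:
                metric_mtx = torch.empty((batch_size, spk_num, spk_num), dtype=val.dtype, device=val.device)
            metric_mtx[:, t, e] = val
    return metric_mtx

# Hypothetical metric that upcasts to float64 even for float32 inputs; an
# allocation based on preds.dtype would have down-cast these values.
def mse_f64(p, t):
    return ((p.double() - t.double()) ** 2).mean(dim=-1)

preds = torch.randn(2, 3, 100)   # (batch, speakers, time)
target = torch.randn(2, 3, 100)
mtx = fill_metric_matrix(preds, target, mse_f64)
print(mtx.dtype)  # torch.float64

The same lazy-allocation idea extends to the device: if the metric happens to compute on a different device than `target`, the matrix is created where the results actually live.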
