From 0b45103cb228a81a9d9d776cca92694cb30ddb41 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Sun, 22 Oct 2023 16:44:32 +0200 Subject: [PATCH] feat(utils): add "soft" option to Powerset.to_multilabel conversion (#1516) --- CHANGELOG.md | 1 + pyannote/audio/utils/powerset.py | 22 ++++++++++++++-------- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fcdebb82c..c63e65bea 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ - feat(pipeline): add support for list of hooks with `Hooks` - BREAKING(pipeline): remove `logging_hook` (use `ArtifactHook` instead) - fix(pipeline): add missing "embedding" hook call in `SpeakerDiarization` +- feat(utils): add `"soft"` option to `Powerset.to_multilabel` ## Version 3.0.1 (2023-09-28) diff --git a/pyannote/audio/utils/powerset.py b/pyannote/audio/utils/powerset.py index 0f5cfb5bc..810519829 100644 --- a/pyannote/audio/utils/powerset.py +++ b/pyannote/audio/utils/powerset.py @@ -84,26 +84,32 @@ def build_cardinality(self) -> torch.Tensor: powerset_k += 1 return cardinality - def to_multilabel(self, powerset: torch.Tensor) -> torch.Tensor: - """Convert predictions from (soft) powerset to (hard) multi-label + def to_multilabel(self, powerset: torch.Tensor, soft: bool = False) -> torch.Tensor: + """Convert predictions from powerset to multi-label Parameter --------- powerset : (batch_size, num_frames, num_powerset_classes) torch.Tensor Soft predictions in "powerset" space. + soft : bool, optional + Return soft multi-label predictions. Defaults to False (i.e. hard predictions). + Assumes that `powerset` contains "logits" (not "probabilities"). Returns ------- multi_label : (batch_size, num_frames, num_classes) torch.Tensor - Hard predictions in "multi-label" space. + Predictions in "multi-label" space. 
""" - hard_powerset = torch.nn.functional.one_hot( - torch.argmax(powerset, dim=-1), - self.num_powerset_classes, - ).float() + if soft: + powerset_probs = torch.exp(powerset) + else: + powerset_probs = torch.nn.functional.one_hot( + torch.argmax(powerset, dim=-1), + self.num_powerset_classes, + ).float() - return torch.matmul(hard_powerset, self.mapping) + return torch.matmul(powerset_probs, self.mapping) def forward(self, powerset: torch.Tensor) -> torch.Tensor: """Alias for `to_multilabel`"""