diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6e7d220fd..a6f2161c5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -32,15 +32,16 @@
 
 ### Features and improvements
 
+  - feat(task): add [powerset](https://www.isca-speech.org/archive/interspeech_2023/plaquet23_interspeech.html) support to `SpeakerDiarization` task
   - feat(task): add support for multi-task models
+  - feat(task): add support for label scope in speaker diarization task
+  - feat(task): add support for missing classes in multi-label segmentation task
+  - feat(model): add segmentation model based on torchaudio self-supervised representation
   - feat(pipeline): send pipeline to device with `pipeline.to(device)`
-  - feat(pipeline): make `segmentation_batch_size` and `embedding_batch_size` mutable in `SpeakerDiarization` pipeline (they now default to `1`)
-  - feat(task): add [powerset](https://arxiv.org/PLACEHOLDER) support to `SpeakerDiarization` task
   - feat(pipeline): add `return_embeddings` option to `SpeakerDiarization` pipeline
+  - feat(pipeline): make `segmentation_batch_size` and `embedding_batch_size` mutable in `SpeakerDiarization` pipeline (they now default to `1`)
   - feat(pipeline): add progress hook to pipelines
   - feat(pipeline): check version compatibility at load time
-  - feat(task): add support for label scope in speaker diarization task
-  - feat(task): add support for missing classes in multi-label segmentation task
   - improve(task): load metadata as tensors rather than pyannote.core instances
   - improve(task): improve error message on missing specifications
 
diff --git a/pyannote/audio/cli/train_config/model/SSeRiouSS.yaml b/pyannote/audio/cli/train_config/model/SSeRiouSS.yaml
new file mode 100644
index 000000000..73f7f963a
--- /dev/null
+++ b/pyannote/audio/cli/train_config/model/SSeRiouSS.yaml
@@ -0,0 +1,13 @@
+# @package _group_
+_target_: pyannote.audio.models.segmentation.SSeRiouSS
+wav2vec: WAVLM_BASE
+wav2vec_layer: -1
+lstm:
+  hidden_size: 128
+  num_layers: 4
+  bidirectional: true
+  monolithic: true
+  dropout: 0.5
+linear:
+  hidden_size: 128
+  num_layers: 2
diff --git a/pyannote/audio/models/segmentation/SSeRiouSS.py b/pyannote/audio/models/segmentation/SSeRiouSS.py
new file mode 100644
index 000000000..7cd545177
--- /dev/null
+++ b/pyannote/audio/models/segmentation/SSeRiouSS.py
@@ -0,0 +1,234 @@
+# MIT License
+#
+# Copyright (c) 2023- CNRS
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchaudio
+from pyannote.core.utils.generators import pairwise
+
+from pyannote.audio.core.model import Model
+from pyannote.audio.core.task import Task
+from pyannote.audio.utils.params import merge_dict
+
+
+class SSeRiouSS(Model):
+    """Self-Supervised Representation for Speaker Segmentation
+
+    wav2vec > LSTM > Feed forward > Classifier
+
+    Parameters
+    ----------
+    sample_rate : int, optional
+        Audio sample rate. Defaults to 16kHz (16000).
+    num_channels : int, optional
+        Number of channels. Defaults to mono (1).
+    wav2vec: dict or str, optional
+        Defaults to "WAVLM_BASE".
+    wav2vec_layer: int, optional
+        Index of layer to use as input to the LSTM.
+        Defaults (-1) to use average of all layers (with learnable weights).
+    lstm : dict, optional
+        Keyword arguments passed to the LSTM layer.
+        Defaults to {"hidden_size": 128, "num_layers": 4, "bidirectional": True},
+        i.e. four bidirectional layers with 128 units each.
+        Set "monolithic" to False to split monolithic multi-layer LSTM into multiple mono-layer LSTMs.
+        This may prove useful for probing LSTM internals.
+    linear : dict, optional
+        Keyword arguments used to initialize linear layers.
+        Defaults to {"hidden_size": 128, "num_layers": 2},
+        i.e. two linear layers with 128 units each.
+    """
+
+    WAV2VEC_DEFAULTS = "WAVLM_BASE"
+
+    LSTM_DEFAULTS = {
+        "hidden_size": 128,
+        "num_layers": 4,
+        "bidirectional": True,
+        "monolithic": True,
+        "dropout": 0.0,
+    }
+    LINEAR_DEFAULTS = {"hidden_size": 128, "num_layers": 2}
+
+    def __init__(
+        self,
+        wav2vec: Union[dict, str] = None,
+        wav2vec_layer: int = -1,
+        lstm: dict = None,
+        linear: dict = None,
+        sample_rate: int = 16000,
+        num_channels: int = 1,
+        task: Optional[Task] = None,
+    ):
+        super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task)
+
+        if isinstance(wav2vec, str):
+            # `wav2vec` is one of the supported pipelines from torchaudio (e.g. "WAVLM_BASE")
+            if hasattr(torchaudio.pipelines, wav2vec):
+                bundle = getattr(torchaudio.pipelines, wav2vec)
+                if sample_rate != bundle._sample_rate:
+                    raise ValueError(
+                        f"Expected {bundle._sample_rate}Hz, found {sample_rate}Hz."
+                    )
+                wav2vec_dim = bundle._params["encoder_embed_dim"]
+                wav2vec_num_layers = bundle._params["encoder_num_layers"]
+                self.wav2vec = bundle.get_model()
+
+            # `wav2vec` is a path to a self-supervised representation checkpoint
+            else:
+                _checkpoint = torch.load(wav2vec)
+                wav2vec = _checkpoint.pop("config")
+                self.wav2vec = torchaudio.models.wav2vec2_model(**wav2vec)
+                state_dict = _checkpoint.pop("state_dict")
+                self.wav2vec.load_state_dict(state_dict)
+                wav2vec_dim = wav2vec["encoder_embed_dim"]
+                wav2vec_num_layers = wav2vec["encoder_num_layers"]
+
+        # `wav2vec` is a config dictionary understood by `wav2vec2_model`
+        # this branch is typically used by Model.from_pretrained(...)
+        elif isinstance(wav2vec, dict):
+            self.wav2vec = torchaudio.models.wav2vec2_model(**wav2vec)
+            wav2vec_dim = wav2vec["encoder_embed_dim"]
+            wav2vec_num_layers = wav2vec["encoder_num_layers"]
+
+        if wav2vec_layer < 0:
+            self.wav2vec_weights = nn.Parameter(
+                data=torch.ones(wav2vec_num_layers), requires_grad=True
+            )
+
+        lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
+        lstm["batch_first"] = True
+        linear = merge_dict(self.LINEAR_DEFAULTS, linear)
+
+        self.save_hyperparameters("wav2vec", "wav2vec_layer", "lstm", "linear")
+
+        monolithic = lstm["monolithic"]
+        if monolithic:
+            multi_layer_lstm = dict(lstm)
+            del multi_layer_lstm["monolithic"]
+            self.lstm = nn.LSTM(wav2vec_dim, **multi_layer_lstm)
+
+        else:
+            num_layers = lstm["num_layers"]
+            if num_layers > 1:
+                self.dropout = nn.Dropout(p=lstm["dropout"])
+
+            one_layer_lstm = dict(lstm)
+            one_layer_lstm["num_layers"] = 1
+            one_layer_lstm["dropout"] = 0.0
+            del one_layer_lstm["monolithic"]
+
+            self.lstm = nn.ModuleList(
+                [
+                    nn.LSTM(
+                        wav2vec_dim
+                        if i == 0
+                        else lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1),
+                        **one_layer_lstm,
+                    )
+                    for i in range(num_layers)
+                ]
+            )
+
+        if linear["num_layers"] < 1:
+            return
+
+        lstm_out_features: int = self.hparams.lstm["hidden_size"] * (
+            2 if self.hparams.lstm["bidirectional"] else 1
+        )
+        self.linear = nn.ModuleList(
+            [
+                nn.Linear(in_features, out_features)
+                for in_features, out_features in pairwise(
+                    [
+                        lstm_out_features,
+                    ]
+                    + [self.hparams.linear["hidden_size"]]
+                    * self.hparams.linear["num_layers"]
+                )
+            ]
+        )
+
+    def build(self):
+        if self.hparams.linear["num_layers"] > 0:
+            in_features = self.hparams.linear["hidden_size"]
+        else:
+            in_features = self.hparams.lstm["hidden_size"] * (
+                2 if self.hparams.lstm["bidirectional"] else 1
+            )
+
+        if isinstance(self.specifications, tuple):
+            raise ValueError("SSeRiouSS model does not support multi-tasking.")
+
+        if self.specifications.powerset:
+            out_features = self.specifications.num_powerset_classes
+        else:
+            out_features = len(self.specifications.classes)
+
+        self.classifier = nn.Linear(in_features, out_features)
+        self.activation = self.default_activation()
+
+    def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
+        """Pass forward
+
+        Parameters
+        ----------
+        waveforms : (batch, channel, sample)
+
+        Returns
+        -------
+        scores : (batch, frame, classes)
+        """
+
+        num_layers = (
+            None if self.hparams.wav2vec_layer < 0 else self.hparams.wav2vec_layer
+        )
+
+        with torch.no_grad():
+            outputs, _ = self.wav2vec.extract_features(
+                waveforms.squeeze(1), num_layers=num_layers
+            )
+
+        if num_layers is None:
+            outputs = torch.stack(outputs, dim=-1) @ F.softmax(
+                self.wav2vec_weights, dim=0
+            )
+        else:
+            outputs = outputs[-1]
+
+        if self.hparams.lstm["monolithic"]:
+            outputs, _ = self.lstm(outputs)
+        else:
+            for i, lstm in enumerate(self.lstm):
+                outputs, _ = lstm(outputs)
+                if i + 1 < self.hparams.lstm["num_layers"]:
+                    outputs = self.dropout(outputs)
+
+        if self.hparams.linear["num_layers"] > 0:
+            for linear in self.linear:
+                outputs = F.leaky_relu(linear(outputs))
+
+        return self.activation(self.classifier(outputs))
diff --git a/pyannote/audio/models/segmentation/__init__.py b/pyannote/audio/models/segmentation/__init__.py
index 82e149853..9f6f5f6e3 100644
--- a/pyannote/audio/models/segmentation/__init__.py
+++ b/pyannote/audio/models/segmentation/__init__.py
@@ -1,6 +1,6 @@
 # MIT License
 #
-# Copyright (c) 2020 CNRS
+# Copyright (c) 2020- CNRS
 #
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to deal
@@ -21,5 +21,6 @@
 # SOFTWARE.
 
 from .PyanNet import PyanNet
+from .SSeRiouSS import SSeRiouSS
 
-__all__ = ["PyanNet"]
+__all__ = ["PyanNet", "SSeRiouSS"]
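
For reference, a minimal usage sketch showing how the hyper-parameters from `SSeRiouSS.yaml` map onto the new `SSeRiouSS` constructor; this is not part of the patch, and it assumes the `WAVLM_BASE` torchaudio bundle can be downloaded on first use:

    from pyannote.audio.models.segmentation import SSeRiouSS

    # same hyper-parameters as cli/train_config/model/SSeRiouSS.yaml
    model = SSeRiouSS(
        wav2vec="WAVLM_BASE",  # torchaudio pipeline name, config dict, or checkpoint path
        wav2vec_layer=-1,      # -1: learnable weighted average of all transformer layers
        lstm={
            "hidden_size": 128,
            "num_layers": 4,
            "bidirectional": True,
            "monolithic": True,
            "dropout": 0.5,
        },
        linear={"hidden_size": 128, "num_layers": 2},
    )

    # Note: the classifier head is only created later, in Model.build(), once task
    # specifications define the output classes (e.g. when the model is attached
    # to a segmentation task for training or restored from a checkpoint).

In the Hydra-based training CLI, the new `SSeRiouSS.yaml` selects this model through its `_target_` entry, so the usual task/protocol configuration should work unchanged when overriding the model group with this config.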