From d03906bbc1372c93c04b9cae5f5e98d713f89926 Mon Sep 17 00:00:00 2001
From: SevKod
Date: Tue, 9 May 2023 09:16:19 +0200
Subject: [PATCH] implement WavLM inside PyanNet class and add WavLM block

---
 pyannote/audio/models/blocks/wavlm.py         | 44 +++++++++
 pyannote/audio/models/segmentation/PyanNet.py | 89 ++++++------------
 2 files changed, 71 insertions(+), 62 deletions(-)
 create mode 100644 pyannote/audio/models/blocks/wavlm.py

diff --git a/pyannote/audio/models/blocks/wavlm.py b/pyannote/audio/models/blocks/wavlm.py
new file mode 100644
index 000000000..5c79cb286
--- /dev/null
+++ b/pyannote/audio/models/blocks/wavlm.py
@@ -0,0 +1,44 @@
+# MIT License
+#
+# Copyright (c) 2020 CNRS
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import AutoModel
+
+class WavLM(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+        self.wvlm = AutoModel.from_pretrained('microsoft/wavlm-base')  # load pretrained WavLM Base from Hugging Face
+
+    def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
+
+        waveforms = torch.squeeze(waveforms, 1)  # (batch, channel, sample) -> (batch, sample)
+        with torch.no_grad():
+            outputs = self.wvlm(waveforms).extract_features  # frozen forward pass; output of the last convolutional feature encoder layer: (batch, frame, 512)
+
+        return outputs
diff --git a/pyannote/audio/models/segmentation/PyanNet.py b/pyannote/audio/models/segmentation/PyanNet.py
index 10991a1f6..9ebaddbe0 100644
--- a/pyannote/audio/models/segmentation/PyanNet.py
+++ b/pyannote/audio/models/segmentation/PyanNet.py
@@ -1,26 +1,3 @@
-# MIT License
-#
-# Copyright (c) 2020 CNRS
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-
-
 from typing import Optional
 
 import torch
@@ -32,22 +9,9 @@
 from pyannote.audio.core.model import Model
 from pyannote.audio.core.task import Task
 from pyannote.audio.models.blocks.sincnet import SincNet
+from pyannote.audio.models.blocks.wavlm import WavLM
 from pyannote.audio.utils.params import merge_dict
 
-##WAVLM_BASE
-#Requires to pass the PyanNet model to cuda during training script
-
-#Model is loaded outside of the PyanNet class
-
-from transformers import AutoModel
-
-#Loading the model from HuggingFace (requires git lfs to load the .bin checkpoint)
-#model = AutoModel.from_pretrained('/content/drive/MyDrive/PyanNet/wavlm-base')
-
-model = AutoModel.from_pretrained('microsoft/wavlm-base')
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model.to(device) #Pass the model to the gpu (supposing that accelerator = gpu in the TorchLightning Trainer)
 
 class PyanNet(Model):
     """PyanNet segmentation model
@@ -76,7 +40,6 @@ class PyanNet(Model):
     """
 
     SINCNET_DEFAULTS = {"stride": 10}
-
     LSTM_DEFAULTS = {
         "hidden_size": 128,
         "num_layers": 2,
@@ -88,6 +51,7 @@ class PyanNet(Model):
 
     def __init__(
         self,
+        model: Optional[str] = None,
         sincnet: dict = None,
         lstm: dict = None,
         linear: dict = None,
@@ -104,15 +68,22 @@ def __init__(
         lstm["batch_first"] = True
         linear = merge_dict(self.LINEAR_DEFAULTS, linear)
         self.save_hyperparameters("sincnet", "lstm", "linear")
+        self.model = model
+
+        if model == "wavlm":
+            self.wavlm = WavLM()
+            feat_size = 512
+        else:
+            self.sincnet = SincNet(**self.hparams.sincnet)
+            feat_size = 60
 
-        self.sincnet = SincNet(**self.hparams.sincnet)
-
         monolithic = lstm["monolithic"]
         if monolithic:
             multi_layer_lstm = dict(lstm)
             del multi_layer_lstm["monolithic"]
-            self.lstm = nn.LSTM(512, **multi_layer_lstm)
+            self.lstm = nn.LSTM(feat_size, **multi_layer_lstm)
+
         else:
             num_layers = lstm["num_layers"]
             if num_layers > 1:
@@ -126,7 +96,7 @@ def __init__(
             self.lstm = nn.ModuleList(
                 [
                     nn.LSTM(
-                        512
+                        feat_size
                         if i == 0
                         else lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1),
                         **one_layer_lstm
@@ -182,34 +152,28 @@ def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
         -------
         scores : (batch, frame, classes)
         """
-        #outputs = self.sincnet(waveforms)
-
-        #WavLM feature extraction
-
-        waveforms = torch.squeeze(waveforms,1) #waveforms : (batch, channel, sample) -> (batch,sample)
-        with torch.no_grad():
-            feat = model(waveforms) #Compute the features and extract last hidden layer weights
-
-        outputs = feat.extract_features #Get the features : outputs : (batch, frame, feature)
-
+        if self.model == "wavlm":
+            outputs = self.wavlm(waveforms)
+        else:
+            outputs = self.sincnet(waveforms)
+
         if self.hparams.lstm["monolithic"]:
-            #No need to rearrange the output, as the features are already structured in (batch frame feature)
-
-            #outputs, _ = self.lstm(
-            #    rearrange(outputs, "batch feature frame -> batch frame feature"))
-            outputs, _ = self.lstm(outputs)
-
+            if self.model == "wavlm":
+                outputs, _ = self.lstm(outputs)
+            else:
+                outputs, _ = self.lstm(
+                    rearrange(outputs, "batch feature frame -> batch frame feature")
+                )
         else:
-            #outputs = rearrange(outputs, "batch feature frame -> batch frame feature").cuda()
+            if self.model != "wavlm":
+                outputs = rearrange(outputs, "batch feature frame -> batch frame feature")
feature") for i, lstm in enumerate(self.lstm): outputs, _ = lstm(outputs) if i + 1 < self.hparams.lstm["num_layers"]: outputs = self.dropout(outputs) - - + if self.hparams.linear["num_layers"] > 0: for linear in self.linear: outputs = F.leaky_relu(linear(outputs)) return self.activation(self.classifier(outputs)) -