Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Whisper] Add SpecAugment #21298

Merged
1 change: 1 addition & 0 deletions docs/source/en/model_doc/whisper.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ The original code can be found [here](https://github.com/openai/whisper).

[[autodoc]] WhisperModel
- forward
- _mask_input_features

## WhisperForConditionalGeneration

Expand Down
44 changes: 44 additions & 0 deletions src/transformers/models/whisper/configuration_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,35 @@ class WhisperConfig(PretrainedConfig):
begin_suppress_tokens (`List[int]`, *optional*, defaults to `[220,50256]`):
A list containing tokens that will be supressed at the beginning of the sampling process. Initialized as
the token for `" "` (`blank_token_id`) and the `eos_token_id`
apply_spec_augment (`bool`, *optional*, defaults to `False`):
Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
[SpecAugment: A Simple Data Augmentation Method for Automatic Speech
Recognition](https://arxiv.org/abs/1904.08779).
mask_time_prob (`float`, *optional*, defaults to 0.05):
Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
procecure generates `mask_time_prob*len(time_axis)/mask_time_length` independent masks over the axis. If
reasoning from the propability of each feature vector to be chosen as the start of the vector span to be
masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
actual percentage of masked vectors. This is only relevant if `apply_spec_augment == True`.
mask_time_length (`int`, *optional*, defaults to 10):
Length of vector span along the time axis.
mask_time_min_masks (`int`, *optional*, defaults to 2),:
The minimum number of masks of length `mask_feature_length` generated along the time axis, each time step,
irrespectively of `mask_feature_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length <
mask_time_min_masks''
mask_feature_prob (`float`, *optional*, defaults to 0.0):
Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
masking procecure generates ''mask_feature_prob*len(feature_axis)/mask_time_length'' independent masks over
bofenghuang marked this conversation as resolved.
Show resolved Hide resolved
the axis. If reasoning from the propability of each feature vector to be chosen as the start of the vector
span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is
True`.
mask_feature_length (`int`, *optional*, defaults to 10):
Length of vector span along the feature axis.
mask_feature_min_masks (`int`, *optional*, defaults to 0),:
The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
step, irrespectively of `mask_feature_prob`. Only relevant if
''mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks''
bofenghuang marked this conversation as resolved.
Show resolved Hide resolved


Example:
Expand Down Expand Up @@ -185,6 +214,13 @@ def __init__(
eos_token_id=50256,
suppress_tokens=None,
begin_suppress_tokens=[220, 50256],
apply_spec_augment=False,
mask_time_prob=0.05,
mask_time_length=10,
mask_time_min_masks=2,
mask_feature_prob=0.0,
mask_feature_length=10,
mask_feature_min_masks=0,
**kwargs,
):
self.vocab_size = vocab_size
Expand All @@ -208,6 +244,14 @@ def __init__(
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
self.max_source_positions = max_source_positions
self.max_target_positions = max_target_positions
# fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
self.apply_spec_augment = apply_spec_augment
self.mask_time_prob = mask_time_prob
self.mask_time_length = mask_time_length
self.mask_time_min_masks = mask_time_min_masks
self.mask_feature_prob = mask_feature_prob
self.mask_feature_length = mask_feature_length
self.mask_feature_min_masks = mask_feature_min_masks
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/models/whisper/feature_extraction_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ def __call__(
max_length=max_length if max_length else self.n_samples,
truncation=truncation,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
)
# make sure list is in array format
input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
Expand All @@ -318,6 +319,10 @@ def __call__(
else:
padded_inputs["input_features"] = input_features

if return_attention_mask:
# rescale from sample (48000) to feature (3000)
padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length]

if return_tensors is not None:
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)

Expand Down
182 changes: 180 additions & 2 deletions src/transformers/models/whisper/modeling_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import random
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
Expand Down Expand Up @@ -97,6 +98,126 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.LongTensor] = None,
min_masks: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
CPU as part of the preprocessing during training.

Args:
shape: The shape for which to compute masks. This should be of a tuple of size 2 where
the first element is the batch size and the second element is the length of the axis to span.
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
independently generated mask spans of length `mask_length` is computed by
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
actual percentage will be smaller.
mask_length: size of the mask
min_masks: minimum number of masked spans
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
each batch dimension.
"""
batch_size, sequence_length = shape

if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")

if mask_length > sequence_length:
raise ValueError(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
f" and `sequence_length`: {sequence_length}`"
)

# epsilon is used for probabilistic rounding
epsilon = np.random.rand(1).item()

def compute_num_masked_span(input_length):
"""Given input length, compute how many spans should be masked"""
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
num_masked_span = max(num_masked_span, min_masks)

# make sure num masked span <= sequence_length
if num_masked_span * mask_length > sequence_length:
num_masked_span = sequence_length // mask_length

# make sure num_masked span is also <= input_length - (mask_length - 1)
if input_length - (mask_length - 1) < num_masked_span:
num_masked_span = max(input_length - (mask_length - 1), 0)

return num_masked_span

# compute number of masked spans in batch
input_lengths = (
attention_mask.sum(-1).detach().tolist()
if attention_mask is not None
else [sequence_length for _ in range(batch_size)]
)

# SpecAugment mask to fill
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
spec_aug_mask_idxs = []

max_num_masked_span = compute_num_masked_span(sequence_length)

if max_num_masked_span == 0:
return spec_aug_mask

for input_length in input_lengths:
# compute num of masked spans for this input
num_masked_span = compute_num_masked_span(input_length)

# get random indices to mask
spec_aug_mask_idx = np.random.choice(
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
)

# pick first sampled index that will serve as a dummy index to pad vector
# to ensure same dimension for all batches due to probabilistic rounding
# Picking first sample just pads those vectors twice.
if len(spec_aug_mask_idx) == 0:
# this case can only happen if `input_length` is strictly smaller then
# `sequence_length` in which case the last token has to be a padding
# token which we can use as a dummy mask id
dummy_mask_idx = sequence_length - 1
else:
dummy_mask_idx = spec_aug_mask_idx[0]

spec_aug_mask_idx = np.concatenate(
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
)
spec_aug_mask_idxs.append(spec_aug_mask_idx)

spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)

# expand masked indices to masked spans
spec_aug_mask_idxs = np.broadcast_to(
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)

# add offset to the starting indexes so that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
)
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets

# ensure that we cannot have indices larger than sequence_length
if spec_aug_mask_idxs.max() > sequence_length - 1:
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1

# scatter indices to mask
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)

return spec_aug_mask
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

at this point, the input features were converted potentially from a torch.tensor to numpy no? Also if it what on GPU, attention_mask.sum(-1).detach().tolist() will need the .cpu() call.
Handling devices is gonna be a hassle.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function just generates a random NumPy mask, which will then be converted to tensor and mask the input features.

For attention_mask.sum(-1).detach().tolist(), not sure if it's a good practice, but tolist() automatically moves tensors to CPU if necessary

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Think it's fine to do this (copied from W2V2)



class WhisperPositionalEmbedding(nn.Embedding):
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
super().__init__(num_positions, embedding_dim)
Expand Down Expand Up @@ -503,6 +624,14 @@ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing *SpecAugment* data augmentation on padding token indices. Mask values selected in
`[0, 1]`:

- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.

[What are attention masks?](../glossary#attention-mask)
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.

Expand Down Expand Up @@ -999,11 +1128,56 @@ def freeze_encoder(self):
"""
self.encoder._freeze_parameters()

# Modified from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
bofenghuang marked this conversation as resolved.
Show resolved Hide resolved
def _mask_input_features(
self,
input_features: torch.FloatTensor,
attention_mask: Optional[torch.LongTensor] = None,
):
"""
Masks extracted features along time axis and/or along feature axis according to
[SpecAugment](https://arxiv.org/abs/1904.08779).
"""

# `config.apply_spec_augment` can set masking to False
if not getattr(self.config, "apply_spec_augment", True):
return input_features

# generate indices & apply SpecAugment along time axis
batch_size, hidden_size, sequence_length = input_features.size()

if self.config.mask_time_prob > 0 and self.training:
# generate indices & apply SpecAugment along time axis
mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length),
mask_prob=self.config.mask_time_prob,
mask_length=self.config.mask_time_length,
attention_mask=attention_mask,
min_masks=self.config.mask_time_min_masks,
)
mask_time_indices = torch.tensor(mask_time_indices, device=input_features.device, dtype=torch.bool)
mask_time_indices = mask_time_indices[:, None].expand(-1, hidden_size, -1)
input_features[mask_time_indices] = 0

if self.config.mask_feature_prob > 0 and self.training:
# generate indices & apply SpecAugment along feature axis
mask_feature_indices = _compute_mask_indices(
(batch_size, hidden_size),
mask_prob=self.config.mask_feature_prob,
mask_length=self.config.mask_feature_length,
min_masks=self.config.mask_feature_min_masks,
)
mask_feature_indices = torch.tensor(mask_feature_indices, device=input_features.device, dtype=torch.bool)
input_features[mask_feature_indices] = 0

return input_features

@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_features: Optional[torch.LongTensor] = None,
input_features: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -1044,6 +1218,8 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if encoder_outputs is None:
input_features = self._mask_input_features(input_features, attention_mask=attention_mask)

encoder_outputs = self.encoder(
input_features,
head_mask=head_mask,
Expand Down Expand Up @@ -1139,7 +1315,8 @@ def freeze_encoder(self):
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_features: Optional[torch.LongTensor] = None,
input_features: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -1193,6 +1370,7 @@ def forward(

outputs = self.model(
input_features,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
Expand Down
36 changes: 36 additions & 0 deletions tests/models/whisper/test_modeling_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ def test_forward_signature(self):

expected_arg_names = [
"input_features",
"attention_mask",
"decoder_input_ids",
"decoder_attention_mask",
]
Expand Down Expand Up @@ -1289,3 +1290,38 @@ def test_tiny_timestamp_generation(self):

transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

@slow
def test_tiny_specaugment_librispeech(self):
torch_device = "cpu"
set_seed(0)
# Apply SpecAugment
model = WhisperModel.from_pretrained("openai/whisper-tiny", apply_spec_augment=True)
# Set model to training mode to enable SpecAugment
model.train()
model.to(torch_device)
input_speech = self._load_datasamples(1)
feature_extractor = WhisperFeatureExtractor()
input_features = feature_extractor(input_speech, return_tensors="pt").input_features

with torch.no_grad():
logits = model(
input_features,
decoder_input_ids=torch.tensor([[50258, 50259, 50359]]),
output_hidden_states=False,
output_attentions=False,
return_dict=False,
use_cache=False,
)

# fmt: off
EXPECTED_LOGITS = torch.tensor(
[
0.9362, -4.7105, 5.0879, 3.9642, 1.0013, -6.0096, 4.7285, -3.1847,
-0.8648, 1.9631, 6.2653, 3.6936, 0.3575, -4.5818, 3.0564, 7.8712,
2.9951, 0.6848, 9.9497, -2.6638, 1.1571, -6.8546, -1.4333, -7.7584,
1.1200, 3.9030, 4.4655, -4.4919, -1.1703, 9.6241
]
)
# fmt: on
self.assertTrue(torch.allclose(logits[0][0, 0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))