Whisper encoder + No 30 second padding (#5)
* enable whisper model (no need for max_padding)

* bugfix: boolq_in didn't have audio_transcript -> use _get_transcribe_sample

* bugfix: move model to device before merging lora weights

* rename modified_whisper -> whisper_model_modified
farzadab authored Jun 4, 2024
1 parent 56a8209 commit edc3797
Showing 5 changed files with 175 additions and 11 deletions.
13 changes: 6 additions & 7 deletions ultravox/data/datasets.py
@@ -15,6 +15,7 @@
import soundfile as sf
import streaming as mds
import torch
+import torch.nn.functional as F
import transformers
from torch.utils import data

@@ -71,8 +72,10 @@ class DataCollatorForSeq2SeqWithAudio(transformers.DataCollatorForSeq2Seq):
    def __call__(self, features, *args, **kwargs):
        audio_features = [f.pop("audio_values") for f in features]
        batch = super().__call__(features, *args, **kwargs)
-        batch["audio_values"] = torch.nn.utils.rnn.pad_sequence(
-            audio_features, batch_first=True
+        # Pad the last dimension of all audio_values to the same length, with 0s on the right.
+        max_len = max([x.shape[-1] for x in audio_features])
+        batch["audio_values"] = torch.stack(
+            [F.pad(x, (0, max_len - x.shape[-1])) for x in audio_features]
        )
        return batch
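
Note: for reference, a standalone sketch (not part of this diff) of what the new collator padding does. Each audio tensor is right-padded along its last dimension to the longest one in the batch and then stacked; the 1-D waveform shapes below are illustrative, and the same code presumably also covers 2-D mel features laid out as (mel_bins, frames), which is why padding targets the last dimension rather than going through pad_sequence.

import torch
import torch.nn.functional as F

# Two mono waveforms of different lengths (1.0 s and 0.5 s at 16 kHz).
audio_features = [torch.randn(16000), torch.randn(8000)]

# Right-pad the last dimension of every tensor to the batch maximum, then stack.
max_len = max(x.shape[-1] for x in audio_features)
batch_audio = torch.stack([F.pad(x, (0, max_len - x.shape[-1])) for x in audio_features])
print(batch_audio.shape)  # torch.Size([2, 16000])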

@@ -434,11 +437,7 @@ def _get_sample(self, idx: int, row: transformers.BatchFeature) -> VoiceSample:

class BoolQInputDataset(BoolQDataset):
    def _get_sample(self, idx: int, row: transformers.BatchFeature) -> VoiceSample:
-        audio_transcript = str(row["question"])
-        return VoiceSample(
-            self._get_transcribe_messages(idx, audio_transcript),
-            self._get_audio(row),
-        )
+        return self._get_transcribe_sample(idx, row, tcol="question")


class LibriSpeechDataset(VoiceDataset):
1 change: 1 addition & 0 deletions ultravox/inference/ultravox_infer.py
@@ -37,6 +37,7 @@ def __init__(
        model = ultravox_model.UltravoxModel.from_pretrained(
            model_path, torch_dtype=dtype
        )
+        model.to(dtype=dtype, device=device)
        model.merge_and_unload()

        tokenizer_id = tokenizer_id or model_path
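
Note: a rough sketch of why the device move is ordered before the merge, written against the generic peft API rather than the UltravoxModel code path (the model ids and device are placeholders): merge_and_unload folds the LoRA deltas into the base weights, so mismatched devices or dtypes at merge time are presumably what this bugfix avoids.

# Hypothetical sketch with the generic peft API; paths and device are placeholders.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("some-base-model", torch_dtype=torch.bfloat16)
model = PeftModel.from_pretrained(base, "some-lora-adapter")
model.to(dtype=torch.bfloat16, device="cuda")  # move first ...
model = model.merge_and_unload()  # ... then fold the LoRA weights into the base model
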
15 changes: 13 additions & 2 deletions ultravox/model/ultravox_model.py
@@ -11,6 +11,7 @@
import transformers.models

from ultravox.model import ultravox_config
+from ultravox.model import whisper_model_modified


class UltravoxModel(
@@ -190,9 +191,19 @@ def _create_audio_tower(cls, config: ultravox_config.UltravoxConfig) -> Union[
        transformers.models.whisper.modeling_whisper.WhisperEncoder,
    ]:
        if config.audio_model_id is not None:
-            audio_tower = transformers.AutoModel.from_pretrained(config.audio_model_id)
+            if "whisper" in config.audio_model_id:
+                audio_tower = whisper_model_modified.WhisperEncoder.from_pretrained(
+                    config.audio_model_id
+                )
+            else:
+                audio_tower = transformers.AutoModel.from_pretrained(
+                    config.audio_model_id
+                )
        else:
-            audio_tower = transformers.AutoModel.from_config(config.audio_config)
+            if "whisper" in config.audio_config._name_or_path:
+                audio_tower = whisper_model_modified.WhisperEncoder(config.audio_config)
+            else:
+                audio_tower = transformers.AutoModel.from_config(config.audio_config)

        if isinstance(
            audio_tower,
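
Note: the whisper branch above relies on the base_model_prefix override in whisper_model_modified.WhisperEncoder (shown further down) so that .from_pretrained can pull just the encoder weights out of a full Whisper checkpoint. A hedged usage sketch, assuming a stock OpenAI checkpoint id:

# Illustrative only: load only the encoder weights from a full Whisper checkpoint.
from ultravox.model import whisper_model_modified

encoder = whisper_model_modified.WhisperEncoder.from_pretrained("openai/whisper-small")
print(type(encoder).__name__)  # WhisperEncoder
print(encoder.config.d_model)  # 768 for whisper-small
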
16 changes: 14 additions & 2 deletions ultravox/model/ultravox_processing.py
@@ -35,6 +35,7 @@ def __init__(
        self,
        audio_processor=None,
        tokenizer=None,
+        audio_padding: str = "longest",
        encoder_ds_factor: int = 320,
        stack_factor: int = 8,
        audio_placeholder: str = "<|audio|>",
@@ -43,10 +44,12 @@
        Args:
            audio_processor: The audio processor for the audio encoder.
            tokenizer: The tokenizer for the language model.
+            audio_padding: The padding strategy for the audio encoder.
            encoder_ds_factor: The downsample factor of the audio encoder.
            stack_factor: The factor by which the audio encoder output is stacked in the multimodal projector.
            audio_placeholder: The placeholder for the audio in the text.
        """
+        self.audio_padding = audio_padding
        self.encoder_ds_factor = encoder_ds_factor
        self.stack_factor = stack_factor
        self.audio_placeholder = audio_placeholder
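
Note: a hedged sketch of wiring the new audio_padding knob into the processor, assuming the class in this file is UltravoxProcessor and that a Whisper feature extractor plus a Llama tokenizer are used (the checkpoint ids are placeholders, not part of this diff):

# Hypothetical wiring; the class and checkpoint names are assumptions.
import transformers
from ultravox.model import ultravox_processing

audio_processor = transformers.AutoFeatureExtractor.from_pretrained("openai/whisper-small")
tokenizer = transformers.AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

processor = ultravox_processing.UltravoxProcessor(
    audio_processor=audio_processor,
    tokenizer=tokenizer,
    audio_padding="longest",  # or "max_length" to keep Whisper's fixed 30-second window
)
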
@@ -107,7 +110,12 @@ def __call__(
        data = {}
        audio_embed_frames = 0
        if audio is not None and len(audio) > 0:
-            audio_len = audio.shape[-1]
+            if self.audio_padding == "max_length":
+                # 30 seconds is the expected length for Whisper
+                assert sampling_rate is not None, "Sampling rate must be provided."
+                audio_len = 30 * sampling_rate
+            else:
+                audio_len = audio.shape[-1]
            # It's guaranteed that the number of frames is less than or equal to this amount.
            # For Whisper this is exact AFAICT, but for Wav2Vec2 it's an upper bound.
            # Currently, StackAudioFrames makes sure an over-estimation won't cause issues by padding the audio embeddings.
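
Note: the exact expression for audio_embed_frames is collapsed out of this hunk and is not reproduced here. As an illustration only, an upper bound consistent with the comments above could be computed with a hypothetical helper like this, using the encoder_ds_factor=320 and stack_factor=8 defaults from __init__:

import math

def estimate_audio_token_len(audio_len: int, encoder_ds_factor: int = 320, stack_factor: int = 8) -> int:
    # Hypothetical upper bound: one encoder frame per encoder_ds_factor samples,
    # then one audio token per stack_factor encoder frames.
    encoder_frames = math.ceil(audio_len / encoder_ds_factor)
    return math.ceil(encoder_frames / stack_factor)

print(estimate_audio_token_len(16000))       # 1 s at 16 kHz -> 7 tokens
print(estimate_audio_token_len(30 * 16000))  # full 30 s window -> 188 tokens
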
@@ -116,7 +124,11 @@ def __call__(
            data["audio_token_len"] = [audio_embed_frames]

            x = self.audio_processor(
-                audio, sampling_rate=sampling_rate, padding="longest", **kwargs
+                audio,
+                sampling_rate=sampling_rate,
+                padding="longest",
+                max_length=audio_len,
+                **kwargs,
            )
            if "input_features" in x:
                data["audio_values"] = x.input_features
141 changes: 141 additions & 0 deletions ultravox/model/whisper_model_modified.py
@@ -0,0 +1,141 @@
# modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
# see this issue for the commentary: https://github.com/huggingface/transformers/issues/25744
#
# Copyright 2022 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import transformers
import transformers.modeling_outputs
from transformers.models.whisper import modeling_whisper as whisper


class WhisperEncoder(whisper.WhisperEncoder):
    """
    Encoder portion of OpenAI's Whisper model.

    This implementation is a slightly modified version of HF Transformers' Whisper Encoder, with only a few fixes:
    1. base_model_prefix updated to allow for doing `.from_pretrained` directly on the encoder
    2. allow less than 30 seconds of audio padding to be passed in:
        - relaxed ValueError check for `input_features` length to be less than or equal to `expected_seq_length` instead of strictly equal
        - embed_pos is now sliced to match the length of `inputs_embeds`

    Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
    """

    base_model_prefix = "model.encoder"

    def forward(
        self,
        input_features,
        attention_mask=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        expected_seq_length = (
            self.config.max_source_positions
            * self.conv1.stride[0]
            * self.conv2.stride[0]
        )
        if input_features.shape[-1] > expected_seq_length:
            raise ValueError(
                f"Whisper expects the mel input features to be of length {expected_seq_length} or less, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
            )

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        embed_pos = self.embed_positions.weight[: inputs_embeds.size(-2)]

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            assert head_mask.size()[0] == (
                len(self.layers)
            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            to_drop = False
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:  # skip the layer
                    to_drop = True

            if to_drop:
                layer_outputs = (None, None)
            else:
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        encoder_layer.__call__,
                        hidden_states,
                        None,
                        (head_mask[idx] if head_mask is not None else None),
                        output_attentions,
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        None,
                        layer_head_mask=(
                            head_mask[idx] if head_mask is not None else None
                        ),
                        output_attentions=output_attentions,
                    )

                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, encoder_states, all_attentions]
                if v is not None
            )
        return transformers.modeling_outputs.BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )
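
Note: a hedged end-to-end sketch of what the relaxed length check buys (the checkpoint id and shapes are assumptions): the modified encoder accepts mel inputs shorter than the usual 3000 frames and returns proportionally fewer hidden states.

# Illustrative only: run the modified encoder on features shorter than 30 seconds.
import torch
import transformers

from ultravox.model import whisper_model_modified

fe = transformers.WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
encoder = whisper_model_modified.WhisperEncoder.from_pretrained("openai/whisper-small")

audio = torch.randn(5 * 16000).numpy()  # 5 seconds at 16 kHz
feats = fe(audio, sampling_rate=16000, padding="longest", return_tensors="pt").input_features

with torch.no_grad():
    out = encoder(feats)

# The two convolutions downsample by 2x overall, so ~500 mel frames give ~250 hidden states
# instead of the fixed 1500 produced by a padded 30-second input.
print(out.last_hidden_state.shape)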
