added dynamic_heads
- added parameter `dynamic_heads` to `transcribe()` and `align()`; not supported on faster-whisper and HF models
jianfch committed Sep 30, 2024
1 parent 453013c commit 32235fa
Showing 4 changed files with 141 additions and 50 deletions.
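
For context, a minimal usage sketch of the new parameter (illustrative only; the model size, audio path, and transcript text are placeholders):

import stable_whisper

model = stable_whisper.load_model('base')

# dynamic_heads=True selects the default of 6 cross-attention heads at runtime
result = model.transcribe('audio.mp3', dynamic_heads=True)

# a "heads,iterations" string requests 8 heads refined over 3 iterations
aligned = model.align('audio.mp3', 'the transcript text', language='en', dynamic_heads='8,3')
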
10 changes: 10 additions & 0 deletions README.md
@@ -255,6 +255,11 @@ Docstrings:
Whether to ignore warnings for compatibility issues with the detected Whisper version.
extra_models : list of whisper.model.Whisper, optional
List of additional Whisper model instances to use for computing word-timestamps along with ``model``.
dynamic_heads : bool or int or str, optional
Whether to find optimal cross-attention heads at runtime instead of using the predefined heads for
word-timestamp extraction. Specify the number of heads or `True` for the default of 6 heads.
To specify the number of iterations for finding the optimal heads,
use a string with "," separating the number of heads and the number of iterations (e.g. "8,3" for 8 heads and 3 iterations).
decode_options
Keyword arguments to construct :class:`whisper.decode.DecodingOptions` instances.

@@ -935,6 +940,11 @@ Docstring:
Only if ``presplit=True``, ``gap_padding`` is prepended to each segment for word timing alignment.
Used to reduce the probability of the model predicting timestamps earlier than the first utterance.
Ignored if ``model`` is a faster-whisper model.
dynamic_heads : bool or int or str, optional
Whether to find optimal cross-attention heads at runtime instead of using the predefined heads for
word-timestamp extraction. Specify the number of heads or `True` for the default of 6 heads.
To specify the number of iterations for finding the optimal heads,
use a string with "," separating the number of heads and the number of iterations (e.g. "8,3" for 8 heads and 3 iterations).

Returns
-------
11 changes: 9 additions & 2 deletions stable_whisper/alignment.py
@@ -68,7 +68,8 @@ def align(
failure_threshold: Optional[float] = None,
extra_models: Optional[List["Whisper"]] = None,
presplit: Union[bool, List[str]] = True,
gap_padding: str = ' ...'
gap_padding: str = ' ...',
dynamic_heads: Optional[Union[bool, int, str]] = None
) -> Union[WhisperResult, None]:
"""
Align plain text or tokens with audio at word-level.
@@ -172,6 +173,11 @@
Only if ``presplit=True``, ``gap_padding`` is prepended to each segment for word timing alignment.
Used to reduce the probability of the model predicting timestamps earlier than the first utterance.
Ignored if ``model`` is a faster-whisper model.
dynamic_heads : bool or int or str, optional
Whether to find optimal cross-attention heads at runtime instead of using the predefined heads for
word-timestamp extraction. Specify the number of heads or `True` for the default of 6 heads.
To specify the number of iterations for finding the optimal heads,
use a string with "," separating the number of heads and the number of iterations (e.g. "8,3" for 8 heads and 3 iterations).
Returns
-------
@@ -357,7 +363,8 @@ def timestamp_words():
append_punctuations=append_punctuations,
gap_padding=gap_padding if presplit else None,
extra_models=extra_models,
pad_first_seg=pad_first_seg
pad_first_seg=pad_first_seg,
dynamic_heads=dynamic_heads
)
if len(temp_segments) == 1:
return temp_segments[0]
161 changes: 114 additions & 47 deletions stable_whisper/timing.py
@@ -2,7 +2,7 @@
import string
import torch
import numpy as np
from typing import TYPE_CHECKING, List, Callable, Optional, Tuple
from typing import TYPE_CHECKING, List, Callable, Optional, Union
from itertools import chain
from dataclasses import dataclass

@@ -21,48 +21,112 @@ class WordTiming:
probability: float


def _new_cache(audio_features=None, extras: Optional[int] = None) -> dict:
return dict(
audio_features=audio_features,
jump_indices=None,
text_token_probs=None,
qks=None,
extra_caches=[_new_cache() for _ in range(extras)] if extras else None
)
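
# Illustrative sketch (not part of this commit's code): the cache above is
# created once per alignment pass and reused across refinement iterations, so
# the encoder output and cross-attention QKs are computed only once; `kwargs`
# stands for the argument dict assembled in find_alignment_stable below.
#
#     cache = _new_cache(audio_features=None, extras=0 if extra_models is None else len(extra_models))
#     for _ in range(dynamic_iterations or 1):
#         _compute_jump_indices(cache=cache, **kwargs)  # refines cache['jump_indices']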


def _compute_qks(
model: "Whisper",
tokenizer: "Tokenizer",
text_tokens: List[int],
mel: torch.Tensor,
num_samples: int,
tokens: torch.Tensor,
medfilt_width: int = 7,
qk_scale: float = 1.0,
audio_features: torch.Tensor = None,
) -> Tuple[torch.Tensor, List[float]]:

cache: dict
):
# install hooks on the cross attention layers to retrieve the attention weights
QKs = [None] * model.dims.n_text_layer
cache['qks'] = [None] * model.dims.n_text_layer
hooks = [
block.cross_attn.register_forward_hook(
lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1])
lambda _, ins, outs, index=i: cache['qks'].__setitem__(index, outs[-1])
)
for i, block in enumerate(model.decoder.blocks)
]

with torch.no_grad():
if audio_features is None:
audio_features = model.encoder(mel.unsqueeze(0))
if (audio_features := cache['audio_features']) is None:
audio_features = cache['audio_features'] = model.encoder(mel.unsqueeze(0))
logits = model.decoder(tokens.unsqueeze(0), audio_features)[0]
sampled_logits = logits[len(tokenizer.sot_sequence):, : tokenizer.eot]
token_probs = sampled_logits.softmax(dim=-1)
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
text_token_probs = text_token_probs.tolist()
cache['text_token_probs'] = token_probs[np.arange(len(text_tokens)), text_tokens].tolist()

for hook in hooks:
hook.remove()

# heads * tokens * frames
weights = torch.cat([QKs[_l][:, _h] for _l, _h in model.alignment_heads.indices().T], dim=0)
weights = weights[:, len(tokenizer.sot_sequence): -1, : round(num_samples / N_SAMPLES_PER_TOKEN)]
weights = (weights * qk_scale).softmax(dim=-1)

def _compute_atten_weights(
model: "Whisper",
tokenizer: "Tokenizer",
text_tokens: List[int],
mel: torch.Tensor,
num_samples: int,
tokens: torch.Tensor,
cache: dict,
medfilt_width: int = 7,
qk_scale: float = 1.0,
dynamic_heads_count: Optional[int] = None
) -> torch.Tensor:
if cache['qks'] is None:
_compute_qks(model, tokenizer, text_tokens, mel, tokens, cache)
QKs = cache['qks']
if dynamic_heads_count:
max_qk_len = round(num_samples / N_SAMPLES_PER_TOKEN)
if not cache.get('is_processed_qks'):
QKs = torch.cat([qk[0, :, len(tokenizer.sot_sequence): -1, : max_qk_len] for qk in QKs])
QKs = cache['qks'] = (QKs * qk_scale).softmax(dim=-1)
cache['is_processed_qks'] = True

if cache['jump_indices'] is None:
peaks = QKs.topk(1, dim=-1).indices
else:
jump_indices = np.pad(cache['jump_indices'], (0, 1), constant_values=max_qk_len)
peaks = jump_indices[:-1] + ((jump_indices[1:] - jump_indices[:-1]) * 0.5)
peaks = torch.from_numpy(peaks).to(QKs.device)[None, :, None]
distances = (peaks.expand_as(QKs) - torch.arange(QKs.size(-1), device=QKs.device)).abs() / 1500
scores = (distances * QKs).sum(dim=-1)
heads = [score.topk(dynamic_heads_count, largest=False).indices for score in scores.T]
weights = torch.stack([QKs[_h, i] for i, _h in enumerate(heads)], dim=1)
else:
weights = torch.cat([QKs[_l][:, _h] for _l, _h in model.alignment_heads.indices().T], dim=0)
weights = weights[:, len(tokenizer.sot_sequence): -1, : round(num_samples / N_SAMPLES_PER_TOKEN)]
weights = (weights * qk_scale).softmax(dim=-1)
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
weights = (weights - mean) / std
weights = median_filter(weights, medfilt_width)

return weights, text_token_probs
return weights
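
# Illustrative, runnable sketch (not part of this commit's code) of the
# dynamic head-selection step above, with made-up tensor sizes: each head is
# scored by how far its attention mass lies from its peak frame (or, on later
# iterations, from the previous jump window), and the lowest-scoring heads win.
#
#     import torch
#     QKs = torch.rand(24, 10, 1500).softmax(dim=-1)  # heads x tokens x frames
#     peaks = QKs.topk(1, dim=-1).indices             # first pass: per-head peak frame
#     distances = (peaks.expand_as(QKs) - torch.arange(QKs.size(-1))).abs() / 1500
#     scores = (distances * QKs).sum(dim=-1)          # low = concentrated attention
#     heads = [s.topk(6, largest=False).indices for s in scores.T]  # 6 best heads per token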


def _compute_jump_indices(
model: "Whisper",
cache: dict,
extra_models: List["Whisper"] = None,
**kwargs
):
weights = _compute_atten_weights(model, cache=cache, **kwargs)
if extra_models:
extra_weights = [weights]
for mi, other_model in enumerate(extra_models):
m = _compute_atten_weights(other_model, cache=cache['extra_caches'][mi], **kwargs)
extra_weights.append(m)
weights = torch.cat(extra_weights, dim=0)
extra_text_token_probs = [c['text_token_probs'] for c in cache['extra_caches']] + [cache['text_token_probs']]
cache['text_token_probs'] = torch.tensor(
extra_text_token_probs,
device=extra_weights[0].device
).mean(dim=0).tolist()

matrix = weights.mean(dim=0)
text_indices, time_indices = dtw(-matrix)

jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
cache['jump_indices'] = time_indices[jumps].clip(min=0)
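
# Illustrative sketch (not part of this commit's code) of how jump indices
# fall out of the DTW path: a new text token starts wherever text_indices
# increments, and time_indices at those steps give each token's first frame.
#
#     text_indices = np.array([0, 0, 1, 1, 1, 2])  # token index along the path
#     time_indices = np.array([0, 1, 2, 3, 4, 5])  # frame index along the path
#     jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
#     time_indices[jumps]  # -> [0, 2, 5]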


# modified version of whisper.timing.find_alignment
Expand All @@ -79,7 +143,8 @@ def find_alignment_stable(
ts_noise: float = None,
token_split=None,
audio_features: torch.Tensor = None,
extra_models: List["Whisper"] = None
extra_models: List["Whisper"] = None,
dynamic_heads: Optional[Union[bool, int, str]] = None
) -> List[WordTiming]:
if extra_models and (invalid_model_types := set(map(type, extra_models)) - {type(model)}):
raise NotImplementedError(f'Got unsupported model type(s): {invalid_model_types}')
@@ -99,45 +164,47 @@
]
).to(model.device)

if token_split is None:
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
else:
words, word_tokens = token_split
words.append(tokenizer.decode([tokenizer.eot]))
word_tokens.append([tokenizer.eot])
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
if dynamic_heads:
if dynamic_heads is True:
dynamic_heads_count = 6
dynamic_iterations = None
elif isinstance(dynamic_heads, int):
dynamic_heads_count = dynamic_heads
dynamic_iterations = None
else:
assert ',' in dynamic_heads
dynamic_heads = dynamic_heads.split(',')
dynamic_heads_count = int(dynamic_heads[0])
dynamic_iterations = int(dynamic_heads[1])
else:
dynamic_heads_count = dynamic_iterations = None
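# Illustrative examples of values accepted by the parsing above:
#   dynamic_heads=True   -> 6 heads, a single pass
#   dynamic_heads=8      -> 8 heads, a single pass
#   dynamic_heads='8,3'  -> 8 heads, 3 refinement iterations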
kwargs = dict(
model=model,
tokenizer=tokenizer,
text_tokens=text_tokens,
mel=mel,
num_samples=num_samples,
tokens=tokens,
qk_scale=qk_scale,
medfilt_width=medfilt_width
medfilt_width=medfilt_width,
extra_models=extra_models,
dynamic_heads_count=dynamic_heads_count
)

weights, text_token_probs = _compute_qks(model, audio_features=audio_features, **kwargs)

if extra_models:
extra_weights = [weights]
extra_text_token_probs = [text_token_probs]
for other_model in extra_models:
m, p = _compute_qks(other_model, **kwargs)
extra_weights.append(m)
extra_text_token_probs.append(p)
weights = torch.cat(extra_weights, dim=0)
text_token_probs = torch.tensor(extra_text_token_probs, device=extra_weights[0].device).mean(dim=0).tolist()

matrix = weights.mean(dim=0)
text_indices, time_indices = dtw(-matrix)

if token_split is None:
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
else:
words, word_tokens = token_split
words.append(tokenizer.decode([tokenizer.eot]))
word_tokens.append([tokenizer.eot])
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))

jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
jump_times = time_indices[jumps].clip(min=0) / TOKENS_PER_SECOND
cache = _new_cache(audio_features=audio_features, extras=0 if extra_models is None else len(extra_models))
for _ in range(dynamic_iterations or 1):
_compute_jump_indices(cache=cache, **kwargs)
jump_times = cache['jump_indices'] / TOKENS_PER_SECOND
start_times = jump_times[word_boundaries[:-1]]
end_times = jump_times[word_boundaries[1:]]
word_probabilities = [
np.mean(text_token_probs[i:j])
np.mean(cache['text_token_probs'][i:j])
for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
]

9 changes: 8 additions & 1 deletion stable_whisper/whisper_word_level/original_whisper.py
@@ -71,6 +71,7 @@ def transcribe_stable(
progress_callback: Callable = None,
ignore_compatibility: bool = False,
extra_models: Optional[List["Whisper"]] = None,
dynamic_heads: Optional[Union[bool, int, str]] = None,
**decode_options) \
-> WhisperResult:
"""
@@ -191,6 +192,11 @@
Whether to ignore warnings for compatibility issues with the detected Whisper version.
extra_models : list of whisper.model.Whisper, optional
List of additional Whisper model instances to use for computing word-timestamps along with ``model``.
dynamic_heads : bool or int or str, optional
Whether to find optimal cross-attention heads at runtime instead of using the predefined heads for
word-timestamp extraction. Specify the number of heads or `True` for the default of 6 heads.
To specify the number of iterations for finding the optimal heads,
use a string with "," separating the number of heads and the number of iterations (e.g. "8,3" for 8 heads and 3 iterations).
decode_options
Keyword arguments to construct :class:`whisper.decode.DecodingOptions` instances.
@@ -590,7 +596,8 @@ def fast_forward():
ts_noise=ts_noise,
split_callback=split_callback,
gap_padding=gap_padding,
extra_models=extra_models
extra_models=extra_models,
dynamic_heads=dynamic_heads
)

for i in reversed(range(len(current_segments))):
