-
Notifications
You must be signed in to change notification settings - Fork 176
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Changed the Python requirement from 3.7+ to 3.8+ (following Whisper).
- More reliable word-level timestamps (using Whisper's new method for word timestamps).
- transcribe() now returns a WhisperResult object, allowing easier manipulation of results.
- WhisperResult contains methods to save the result as JSON, SRT, VTT, or ASS.
- WhisperResult contains methods to regroup segments word by word.
- Added Silero VAD for generating the suppression mask (requires PyTorch 1.2.0+).
- Improved non-VAD suppression.
- Added visualize_suppression() for visualizing suppression based on arguments (requires Pillow or opencv-python).
- SRT/VTT/ASS outputs now all support both segment-level and word-level timestamps.
- Loading branch information
Showing
14 changed files
with
1,993 additions
and
2,237 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,6 @@ | ||
from .stabilization import * | ||
from .text_output import * | ||
from .whisper_word_level import * | ||
from .result import * | ||
from .text_output import * | ||
from .video_output import * | ||
from .stabilization import visualize_suppression | ||
from ._version import __version__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
__version__ = "1.4.0" | ||
__version__ = "2.0.0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
from typing import TYPE_CHECKING, List, Union | ||
from dataclasses import replace | ||
|
||
import torch | ||
import numpy as np | ||
|
||
from whisper.decoding import DecodingTask, DecodingOptions, DecodingResult | ||
|
||
|
||
if TYPE_CHECKING: | ||
from whisper.model import Whisper | ||
|
||
|
||
def _suppress_ts(ts_logits: torch.Tensor, ts_token_mask: torch.Tensor = None): | ||
if ts_token_mask is not None: | ||
ts_logits[:, ts_token_mask] = -np.inf | ||
|
||
|
||
# modified version of whisper.decoding.DecodingTask
class DecodingTaskStable(DecodingTask):
    """
    Subclass of ``whisper.decoding.DecodingTask`` that adds two capabilities:

    - ``ts_token_mask``: boolean mask used to suppress timestamp tokens
      during decoding (e.g. for portions of the audio detected as silent).
    - ``audio_features``: previously computed encoder output that is reused
      so the encoder does not have to be rerun (e.g. on decode fallback).
    """

    def __init__(self, *args, **kwargs):
        # Pop the extra keyword arguments before delegating to the base
        # class, which does not accept them.
        self.ts_token_mask: torch.Tensor = kwargs.pop('ts_token_mask', None)
        self.audio_features: torch.Tensor = kwargs.pop('audio_features', None)
        super(DecodingTaskStable, self).__init__(*args, **kwargs)

    def _get_audio_features(self, mel: torch.Tensor):
        # Run the encoder only once: cache a detached copy of its output,
        # and hand out clones on later calls so callers may mutate the
        # returned tensor without corrupting the cache.
        if self.audio_features is None:
            audio_features = super()._get_audio_features(mel)
            self.audio_features = audio_features.detach().clone()
            return audio_features
        return self.audio_features.clone()

    # modified version of whisper.DecodingTask._main_loop
    def _main_loop(self, audio_features: torch.Tensor, tokens: torch.Tensor):
        """
        Autoregressive sampling loop; identical to the upstream version
        except for the added ``_suppress_ts`` call.

        Returns the expanded token tensor, the per-sequence summed log
        probabilities, and the no-speech probabilities captured at the SOT
        position on the first step.
        """
        assert audio_features.shape[0] == tokens.shape[0]
        n_batch = tokens.shape[0]
        sum_logprobs: torch.Tensor = torch.zeros(n_batch, device=audio_features.device)
        # defaults to NaN when no_speech cannot be measured (no no_speech token)
        no_speech_probs = [np.nan] * n_batch

        try:
            for i in range(self.sample_len):
                logits = self.inference.logits(tokens, audio_features)

                if i == 0 and self.tokenizer.no_speech is not None:  # save no_speech_probs
                    probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
                    no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()

                # now we need to consider the logits at the last token only
                logits = logits[:, -1]

                # apply the logit filters, e.g. for suppressing or applying penalty to repetitions
                for logit_filter in self.logit_filters:
                    logit_filter.apply(logits, tokens)

                # suppress timestamp tokens where the audio is silent so that decoder ignores those timestamps
                _suppress_ts(logits[:, self.tokenizer.timestamp_begin:], self.ts_token_mask)

                # expand the tokens tensor with the selected next tokens
                tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)

                if completed or tokens.shape[-1] > self.n_ctx:
                    break
        finally:
            # ensure KV-cache hooks are removed even if decoding raises
            self.inference.cleanup_caching()

        return tokens, sum_logprobs, no_speech_probs
|
||
|
||
# modified version of whisper.decoding.decode
@torch.no_grad()
def decode_stable(model: "Whisper",
                  mel: torch.Tensor,
                  options: DecodingOptions = DecodingOptions(),
                  ts_token_mask: torch.Tensor = None,
                  audio_features: torch.Tensor = None,
                  **kwargs, ) -> \
        Union[DecodingResult, List[DecodingResult], tuple]:
    """
    Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).

    Parameters
    ----------
    model: Whisper
        The Whisper model modified instance
    mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
        A tensor containing the Mel spectrogram(s)
    options: DecodingOptions
        A dataclass that contains all necessary options for decoding 30-second segments
    ts_token_mask: torch.Tensor
        Mask for suppressing to timestamp token(s) for decoding
    audio_features: torch.Tensor
        reused audio_feature from encoder for fallback

    Returns
    -------
    result: Union[DecodingResult, List[DecodingResult]]
        The result(s) of decoding contained in `DecodingResult` dataclass instance(s),
        paired with the encoder audio features for potential reuse.
    """
    # a 2-D mel means a single (unbatched) segment; add a batch dimension
    is_single = mel.ndim == 2
    if is_single:
        mel = mel.unsqueeze(0)

    # any extra keyword arguments override fields of the options dataclass
    if kwargs:
        options = replace(options, **kwargs)

    decoding_task = DecodingTaskStable(model, options,
                                       ts_token_mask=ts_token_mask,
                                       audio_features=audio_features)
    results = decoding_task.run(mel)

    # unwrap the batch dimension for a single segment, and also return the
    # encoder output so callers can pass it back in on fallback decodes
    if is_single:
        results = results[0]
    return results, decoding_task.audio_features
Oops, something went wrong.