ver 2.0.0
-changed Python requirement from 3.7+ to 3.8+ (following Whisper)
-more reliable word-level timestamps (using Whisper's new method for word timestamps)
-transcribe() returns WhisperResult object (allowing easier manipulation of results)
-WhisperResult contains methods to save result as JSON, SRT, VTT, ASS
-WhisperResult contains methods to regroup segments word by word
-added Silero VAD for generating suppression mask (requires PyTorch 1.2.0+)
-improved non-VAD suppression
-added visualize_suppression() for visualizing suppression based on arguments (requires Pillow or opencv-python)
-SRT/VTT/ASS outputs now all support both segment-level and word-level
jianfch committed Mar 17, 2023
1 parent eb5e68c commit 2248087
Showing 14 changed files with 1,993 additions and 2,237 deletions.
97 changes: 71 additions & 26 deletions README.md
@@ -1,14 +1,28 @@
# Stabilizing Timestamps for Whisper

## Description
This script modifies and adds more robust decoding logic on top of OpenAI's Whisper to produce more accurate segment-level timestamps and obtain word-level timestamps without extra inference.
This script modifies [OpenAI's Whisper](https://github.com/openai/whisper) to produce more reliable timestamps.

![image](https://user-images.githubusercontent.com/28970749/218944014-b915af81-1cf5-4522-a823-e0f476fcc550.png)
![image](./demo/jfk.PNG)


## Update:
The official [Whisper](https://github.com/openai/whisper) repo introduced word-level timestamps in a recent [commit](https://github.com/openai/whisper/commit/500d0fe9668fae5fe2af2b6a3c4950f8a29aa145) which produces more reliable timestamps than the method used in this script.
This script has not been updated to utilize this new version of Whisper yet. It will be updated for the next release, version 2.0.0.
./demo/jfk.mp4


### What's new in 2.0.0?
- updated to use Whisper's more reliable word-level timestamp method.
- the more reliable word timestamps allow regrouping segments word by word.
- can now suppress silence with [Silero VAD](https://github.com/snakers4/silero-vad) (requires PyTorch 1.2.0+)
- non-VAD silence suppression is also more robust


./demo/a.mp4


### Features
- more control over the timestamps than default Whisper
- supports direct preprocessing with [Demucs](https://github.com/facebookresearch/demucs) to isolate voice (see the sketch after this list)
- supports dynamic quantization to decrease memory usage for inference on CPU
- lower memory usage than default Whisper when transcribing very long input audio tracks
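
A hedged sketch of the Demucs preprocessing mentioned in the list above. The `demucs` keyword is an assumption and is not shown in this diff; it also requires the `demucs` package to be installed.
```python
import stable_whisper

model = stable_whisper.load_model('base')
# assumed keyword: demucs=True would run the audio through Demucs to isolate vocals
# before transcribing (requires the demucs package)
result = model.transcribe('audio.mp3', demucs=True)
```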

## Setup
```
@@ -21,52 +35,83 @@ pip install -U git+https://github.com/jianfch/stable-ts.git
```
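
Some features rely on optional packages named elsewhere in this README: [Demucs](https://github.com/facebookresearch/demucs) for voice isolation and [Pillow](https://github.com/python-pillow/Pillow) or [opencv-python](https://github.com/opencv/opencv-python) for `visualize_suppression()`. A sketch of installing them; exact versions are not pinned here.
```commandline
pip install -U demucs
pip install -U pillow
```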

### Command-line usage
Transcribe audio then save result as JSON file.
Transcribe audio then save the result as a JSON file, which contains the original inference results.
This allows the results to be reprocessed differently without having to redo inference.
Change `audio.json` to `audio.srt` to process it directly into SRT.
```commandline
stable-ts audio.mp3 -o audio.json
```
Processing JSON file of the results into ASS.
Process the JSON result file into SRT.
```commandline
stable-ts audio.json -o audio.ass
stable-ts audio.json -o audio.srt
```
Transcribe multiple audio files then process the results directly into SRT files.
```commandline
stable-ts audio1.mp3 audio2.mp3 audio3.mp3 -o audio1.srt audio2.srt audio3.srt
```
Show all available arguments and help.
```commandline
stable-ts -h
```

### Python usage
```python
import stable_whisper

model = stable_whisper.load_model('base')
# the modified model runs just like the regular model but accepts additional parameters
results = model.transcribe('audio.mp3')
result = model.transcribe('audio.mp3')
# srt/vtt
result.to_srt_vtt('audio.srt')
# ass
result.to_ass('audio.ass')
# json
result.save_as_json('audio.json')
```
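
Per the changelog, the SRT/VTT/ASS outputs support both segment-level and word-level timestamps. A hedged sketch of selecting one level or the other; the keyword names below are assumptions and are not confirmed by this diff.
```python
# assumed keywords: segment_level / word_level
result.to_srt_vtt('audio_segments.srt', word_level=False)   # segment-level only
result.to_srt_vtt('audio_words.srt', segment_level=False)   # word-level only
```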

https://user-images.githubusercontent.com/28970749/218942894-cb0b91df-1c14-4d2f-9793-d1c8ef20e711.mp4
### Regrouping Words
Stable-ts has a preset for regrouping words into different segments. This preset is enabled by `regroup=True`.
There are also other built-in regrouping methods that allow you to customize the regrouping logic.
This preset is just a predefined combination of those methods.


./demo/xata.mp4


```python
# the above uses default settings on version 1.1 with large model
# sentence/phrase-level
stable_whisper.results_to_sentence_srt(results, 'audio.srt')
result0 = model.transcribe('audio.mp3', regroup=True) # regroup is True by default
# regroup=True is the same as below
result1 = model.transcribe('audio.mp3', regroup=False)
result1.split_by_punctuation(['.', '。', '?', '?'], True).split_by_gap(.5).merge_by_gap(.15).unlock_all_segments()
# result0 == result1
```
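
Since each regrouping method returns the result object itself, the calls in the example above can be chained to build custom segmentation logic before saving the result.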

https://user-images.githubusercontent.com/28970749/218942942-060610e4-4c96-454d-b00a-8c9a41f4e7de.mp4

### Visualizing Suppression
- Requirement: [Pillow](https://github.com/python-pillow/Pillow) or [opencv-python](https://github.com/opencv/opencv-python)
#### Non-VAD Suppression
![image](./demo/novad.png)
```python
import stable_whisper
# regions of the waveform colored red are where it will likely be suppressed and marked as silent
# [q_levels=20] and [k_size=5] are the defaults for non-VAD.
stable_whisper.visualize_suppression('audio.mp3', 'image.png', q_levels=20, k_size=5)
```
#### VAD Suppression
![image](./demo/vad.png)
```python
# [vad_threshold=0.35] is the default for VAD.
stable_whisper.visualize_suppression('audio.mp3', 'image.png', vad=True, vad_threshold=0.35)
```
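
A hedged sketch of applying the same VAD-based suppression at transcription time; the keyword names are assumed to mirror `visualize_suppression()` and are not confirmed by this diff.
```python
# assumed keywords mirroring visualize_suppression()
result = model.transcribe('audio.mp3', vad=True, vad_threshold=0.35)
```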

### Encode Comparison
```python
# the above uses default settings on version 1.1 with large model
# sentence/phrase-level & word-level
stable_whisper.results_to_sentence_word_ass(results, 'audio.ass')
import stable_whisper

stable_whisper.encode_video_comparison(
'audio.mp3',
['audio_sub1.srt', 'audio_sub2.srt'],
output_videopath='audio.mp4',
labels=['Example 1', 'Example 2']
)
```
#### Additional Info
* Although timestamps are chronological, they can still be very inaccurate depending on the model, audio, and parameters.
* To produce production-ready word-level results, the model needs to be fine-tuned with a high-quality dataset of audio with word-level timestamps.



## License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details
6 changes: 3 additions & 3 deletions setup.py
@@ -15,10 +15,10 @@ def read_me() -> str:
setup(
name="stable-ts",
version=version(),
description="Stabilizing timestamps of OpenAI's Whisper outputs down to word-level.",
description="Modifies OpenAI's Whisper to produce more reliable timestamps.",
long_description=read_me(),
long_description_content_type='text/markdown',
python_requires=">=3.7",
python_requires=">=3.8",
author="Jian",
url="https://github.com/jianfch/stable-ts",
license="MIT",
@@ -31,7 +31,7 @@ def read_me() -> str:
"more-itertools",
"transformers>=4.19.0",
"ffmpeg-python==0.2.0",
"openai-whisper==20230124"
"openai-whisper==20230308"
],
entry_points={
"console_scripts": ["stable-ts=stable_whisper.whisper_word_level:cli"],
6 changes: 4 additions & 2 deletions stable_whisper/__init__.py
@@ -1,4 +1,6 @@
from .stabilization import *
from .text_output import *
from .whisper_word_level import *
from .result import *
from .text_output import *
from .video_output import *
from .stabilization import visualize_suppression
from ._version import __version__
2 changes: 1 addition & 1 deletion stable_whisper/_version.py
@@ -1 +1 @@
__version__ = "1.4.0"
__version__ = "2.0.0"
77 changes: 1 addition & 76 deletions stable_whisper/audio.py
@@ -1,7 +1,6 @@
import numpy as np
import torch
import torchaudio
from torch.nn.functional import interpolate, avg_pool1d, pad
import numpy as np


def voice_freq_filter(wf: (torch.Tensor, np.ndarray), sr: int,
@@ -19,80 +18,6 @@ def voice_freq_filter(wf: (torch.Tensor, np.ndarray), sr: int,
lower_freq)


def prep_wf_mask(wf: (torch.Tensor, np.ndarray),
output_size: int = None,
kernel_size: int = None) \
-> torch.Tensor:
"""
Preprocesses waveform to be processed into timestamp suppression mask.
"""
if isinstance(wf, np.ndarray):
wf = torch.from_numpy(wf).float()
else:
wf = wf.float()
assert wf.dim() < 3, f'waveform must be 1D or 2D, but got {wf.dim()}D'
wf = wf.abs()
if wf.dim() < 3:
unsqueezes = 3 - wf.dim()
wf = wf[[None] * unsqueezes]
else:
unsqueezes = 0
if output_size is not None:
wf = interpolate(wf,
size=output_size,
mode='linear',
align_corners=False)
p = kernel_size // 2 if kernel_size else 0
if not p or p >= wf.shape[-1]:
mask = wf.mul(255).round()
else:
assert kernel_size % 2, f'kernel_size must be odd but got {kernel_size}'
mask = avg_pool1d(pad(wf.mul(255).round(), (p, p), 'reflect'), kernel_size=kernel_size, stride=1).round()
if unsqueezes:
for _ in range(unsqueezes):
mask.squeeze_(0)
return mask


def remove_lower_quantile(prepped_mask: torch.Tensor,
upper_quantile: float = None,
lower_quantile: float = None,
lower_threshold: float = None) -> torch.Tensor:
"""
Removes lower quantile of amplitude from waveform image
"""
if upper_quantile is None:
upper_quantile = 0.85
if lower_quantile is None:
lower_quantile = 0.15
if lower_threshold is None:
lower_threshold = 0.15
prepped_mask = prepped_mask.clone()
mx = torch.quantile(prepped_mask, upper_quantile)
mn = torch.quantile(prepped_mask, lower_quantile)
mn_threshold = (mx - mn) * lower_threshold + mn
prepped_mask[prepped_mask < mn_threshold] = 0
return prepped_mask


def finalize_mask(prepped_mask: torch.Tensor,
suppress_middle=True,
max_index: (list, int) = None) -> torch.Tensor:
"""
Returns a PyTorch Tensor mask of sections with amplitude zero
"""

prepped_mask = prepped_mask.bool()

if not suppress_middle:
nonzero_indices = prepped_mask.nonzero().flatten()
prepped_mask[nonzero_indices[0]:nonzero_indices[-1] + 1] = True
if max_index is not None:
prepped_mask[max_index + 1:] = False

return ~prepped_mask


def load_demucs_model():
from demucs.pretrained import get_model_from_args
return get_model_from_args(type('args', (object,), dict(name='htdemucs', repo=None))).cpu().eval()
113 changes: 113 additions & 0 deletions stable_whisper/decode.py
@@ -0,0 +1,113 @@
from typing import TYPE_CHECKING, List, Union
from dataclasses import replace

import torch
import numpy as np

from whisper.decoding import DecodingTask, DecodingOptions, DecodingResult


if TYPE_CHECKING:
from whisper.model import Whisper


def _suppress_ts(ts_logits: torch.Tensor, ts_token_mask: torch.Tensor = None):
if ts_token_mask is not None:
ts_logits[:, ts_token_mask] = -np.inf


# modified version of whisper.decoding.DecodingTask
class DecodingTaskStable(DecodingTask):

def __init__(self, *args, **kwargs):
self.ts_token_mask: torch.Tensor = kwargs.pop('ts_token_mask', None)
self.audio_features: torch.Tensor = kwargs.pop('audio_features', None)
super(DecodingTaskStable, self).__init__(*args, **kwargs)

def _get_audio_features(self, mel: torch.Tensor):
if self.audio_features is None:
audio_features = super()._get_audio_features(mel)
self.audio_features = audio_features.detach().clone()
return audio_features
return self.audio_features.clone()

# modified version of whisper.DecodingTask._main_loop
def _main_loop(self, audio_features: torch.Tensor, tokens: torch.Tensor):
assert audio_features.shape[0] == tokens.shape[0]
n_batch = tokens.shape[0]
sum_logprobs: torch.Tensor = torch.zeros(n_batch, device=audio_features.device)
no_speech_probs = [np.nan] * n_batch

try:
for i in range(self.sample_len):
logits = self.inference.logits(tokens, audio_features)

if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()

# now we need to consider the logits at the last token only
logits = logits[:, -1]

# apply the logit filters, e.g. for suppressing or applying penalty to
for logit_filter in self.logit_filters:
logit_filter.apply(logits, tokens)

# suppress timestamp tokens where the audio is silent so that decoder ignores those timestamps
_suppress_ts(logits[:, self.tokenizer.timestamp_begin:], self.ts_token_mask)

# expand the tokens tensor with the selected next tokens
tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)

if completed or tokens.shape[-1] > self.n_ctx:
break
finally:
self.inference.cleanup_caching()

return tokens, sum_logprobs, no_speech_probs


# modified version of whisper.decoding.decode
@torch.no_grad()
def decode_stable(model: "Whisper",
mel: torch.Tensor,
options: DecodingOptions = DecodingOptions(),
ts_token_mask: torch.Tensor = None,
audio_features: torch.Tensor = None,
**kwargs, ) -> \
Union[DecodingResult, List[DecodingResult], tuple]:
"""
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
Parameters
----------
model: Whisper
The modified Whisper model instance
mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
A tensor containing the Mel spectrogram(s)
options: DecodingOptions
A dataclass that contains all necessary options for decoding 30-second segments
ts_token_mask: torch.Tensor
Mask for suppressing timestamp token(s) during decoding
audio_features: torch.Tensor
Reused audio features from the encoder for fallback decoding
Returns
-------
result: Union[DecodingResult, List[DecodingResult]]
The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
"""
if single := mel.ndim == 2:
mel = mel.unsqueeze(0)

if kwargs:
options = replace(options, **kwargs)

task = DecodingTaskStable(model, options, ts_token_mask=ts_token_mask, audio_features=audio_features)
result = task.run(mel)

return result[0] if single else result, task.audio_features