added support for DeepFilterNet (#303)
-added DeepFilterNet (https://github.com/Rikorose/DeepFilterNet) to the list of supported denoisers; use with `denoiser="dfnet"`; noise attenuation is set via `denoiser_options`, e.g. limit suppression to 12 dB with `denoiser_options=dict(atten_lim_db=12)` (usage sketch below)
-added Whisper on Hugging Face Transformers to the CLI (enable with `--huggingface_whisper` / `-hw`)
-fixed `WhisperHF.transcribe()` failing to load audio when it is a URL or a file in a format unsupported by the `torchaudio.load()` backend
-fixed the CLI throwing an OSError when the input is a URL and --output is not specified
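
A minimal usage sketch of the new option (the `load_model` call, model size, and file names are assumptions based on the existing stable-ts API rather than part of this commit; DeepFilterNet must be installed with `pip install -U deepfilternet`):

import stable_whisper

model = stable_whisper.load_model('base')
result = model.transcribe(
    'audio.mp3',                              # placeholder input file
    denoiser='dfnet',                         # route the audio through DeepFilterNet
    denoiser_options=dict(atten_lim_db=12),   # limit noise attenuation to 12 dB
)
result.to_srt_vtt('audio.srt')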
jianfch committed Feb 2, 2024
1 parent 9197b5c commit 3fafd04
Showing 7 changed files with 142 additions and 31 deletions.
2 changes: 1 addition & 1 deletion stable_whisper/_version.py
@@ -1 +1 @@
__version__ = "2.15.3"
__version__ = "2.15.4"
10 changes: 6 additions & 4 deletions stable_whisper/audio/__init__.py
@@ -8,14 +8,16 @@
is_ytdlp_available, load_source, load_audio, voice_freq_filter, get_samplerate, get_metadata
)
from .demucs import is_demucs_available, load_demucs_model, demucs_audio
from .dfnet import is_dfnet_available, load_dfnet_model, dfnet_audio
from .output import save_audio_tensor
from ..utils import update_options

from whisper.audio import SAMPLE_RATE


SUPPORTED_DENOISERS = {
'demucs': {'run': demucs_audio, 'load': load_demucs_model, 'access': is_demucs_available}
'demucs': {'run': demucs_audio, 'load': load_demucs_model, 'access': is_demucs_available},
'dfnet': {'run': dfnet_audio, 'load': load_dfnet_model, 'access': is_dfnet_available}
}


@@ -248,7 +250,7 @@ def _load_denoise_model(self):
if not self._denoiser:
return None, None
model = get_denoiser_func(self._denoiser, 'load')(True)
length = int(model.segment * self._sr)
length = int(getattr(model, 'segment', 5) * self._sr)
return model, length

def check_min_chunk_requirement(self):
@@ -384,11 +386,11 @@ def _get_prep_func(self):
if self._final_save_path:
warnings.warn('Both ``save_path`` in AudioLoad and ``denoiser_options`` were specified, '
'but only the final audio will be saved for ``stream=True`` in either case. '
'So AudioLoad will take priority over ``denoiser_options`` for ``save_path``.',
'``denoiser_options`` will be prioritized for ``save_path``.',
stacklevel=2)
self._denoised_save_path = None
else:
self._final_save_path = self._denoised_save_path
self._denoised_save_path = None

denoise_func = get_denoiser_func(self._denoiser, 'run')

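The registry above maps each denoiser name to its availability check ('access'), loader ('load'), and runner ('run'). A rough sketch of how such a registry is dispatched; this `get_denoiser_func` is reconstructed from the calls visible elsewhere in this diff (e.g. `get_denoiser_func(self._denoiser, 'load')(True)`) and may differ from the actual helper:

def get_denoiser_func(denoiser: str, key: str):
    # look up 'access', 'load', or 'run' for a supported denoiser name
    if denoiser not in SUPPORTED_DENOISERS:
        raise NotImplementedError(f'"{denoiser}" is not one of the supported denoisers: '
                                  f'{list(SUPPORTED_DENOISERS)}')
    return SUPPORTED_DENOISERS[denoiser][key]

# e.g. verify DeepFilterNet is installed, load it (cached), then denoise a tensor:
# get_denoiser_func('dfnet', 'access')()
# model = get_denoiser_func('dfnet', 'load')(True)
# denoised = get_denoiser_func('dfnet', 'run')(audio_tensor, input_sr=16000, model=model)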
8 changes: 2 additions & 6 deletions stable_whisper/audio/demucs.py
@@ -3,6 +3,8 @@

import torch

from .utils import load_audio
from ..audio.utils import resample
from ..default import cached_model_instances


@@ -163,12 +165,10 @@ def demucs_audio(
model = load_demucs_model()

if isinstance(audio, (str, bytes)):
from .utils import load_audio
audio = torch.from_numpy(load_audio(audio, model.samplerate))
elif input_sr != model.samplerate:
if input_sr is None:
raise ValueError('No ``input_sr`` specified for audio tensor.')
from ..audio.utils import resample
audio = resample(audio, input_sr, model.samplerate)
audio_dims = audio.dim()
assert audio_dims <= 3
if audio.shape[-2] == 1:
audio = audio.repeat_interleave(2, -2)

if 'mix' in demucs_options:
audio = demucs_options.pop('mix')

if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"

torch.cuda.empty_cache()

if output_sr is not None and model.samplerate != output_sr:
from ..audio.utils import resample
vocals = resample(vocals, model.samplerate, output_sr)

if save_path is not None:
85 changes: 85 additions & 0 deletions stable_whisper/audio/dfnet.py
@@ -0,0 +1,85 @@
from typing import Union, Optional

import torch

from .utils import load_audio
from ..audio.utils import resample
from ..default import cached_model_instances


def is_dfnet_available():
from importlib.util import find_spec
if find_spec('df') is None:
raise ModuleNotFoundError("Please install DeepFilterNet; "
"'pip install -U deepfilternet' or "
"follow installation instructions at https://github.com/Rikorose/DeepFilterNet")


def load_dfnet_model(cache: bool = True, **kwargs):
model_name = 'dfnet'
_model_cache = cached_model_instances['dfnet'] if cache else None
if _model_cache is not None and _model_cache[model_name] is not None:
return _model_cache[model_name]
is_dfnet_available()
from types import MethodType
from df.enhance import init_df, enhance
model, df_state, _ = init_df(**kwargs)
model.df_state = df_state

def enhance_wrapper(_model, audio, **enhance_kwargs):
return enhance(model=_model, df_state=_model.df_state, audio=audio, **enhance_kwargs)

model.enhance = MethodType(enhance_wrapper, model)
model.samplerate = df_state.sr()
if _model_cache is not None:
_model_cache[model_name] = model
return model


def dfnet_audio(
audio: Union[torch.Tensor, str, bytes],
input_sr: int = None,
output_sr: int = None,
model=None,
device=None,
verbose: bool = True,
save_path: Optional[Union[str, callable]] = None,
**dfnet_options
) -> torch.Tensor:
"""
Remove noise from ``audio`` with DeepFilterNet.
Official repo: https://github.com/Rikorose/DeepFilterNet.
"""
if model is None:
model = load_dfnet_model()
if isinstance(audio, (str, bytes)):
audio = torch.from_numpy(load_audio(audio, model.samplerate))
elif input_sr != model.samplerate:
if input_sr is None:
raise ValueError('No ``input_sr`` specified for audio tensor.')
audio = resample(audio, input_sr, model.samplerate)
audio_dims = audio.dim()
assert audio_dims <= 2
if dims_missing := 2 - audio_dims:
audio = audio[[None]*dims_missing]
if audio.shape[-2] == 1:
audio = audio.repeat_interleave(2, -2)

dfnet_options.pop('progress', None) # not implemented
denoised_audio = model.enhance(audio=audio, **dfnet_options).mean(dim=0)

if device != 'cpu':
torch.cuda.empty_cache()

if output_sr is not None and model.samplerate != output_sr:
denoised_audio = resample(denoised_audio, model.samplerate, output_sr)

if save_path is not None:
if isinstance(save_path, str):
from .output import save_audio_tensor
save_audio_tensor(denoised_audio, save_path, output_sr or model.samplerate, verbose=verbose)
else:
save_path(denoised_audio)

return denoised_audio
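
For reference, a hedged example of calling the new helper directly (the file path is a placeholder and DeepFilterNet must be installed); in normal use it is invoked internally when `denoiser='dfnet'` is passed:

from stable_whisper.audio.dfnet import dfnet_audio

# decode, denoise, and resample to 16 kHz; atten_lim_db caps suppression at 12 dB
denoised = dfnet_audio('noisy_speech.wav', output_sr=16000, atten_lim_db=12)
print(denoised.shape)  # 1-D mono tensor (channels are averaged by the helper)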
3 changes: 3 additions & 0 deletions stable_whisper/default.py
@@ -16,6 +16,9 @@
silero_vad={
True: None,
False: None
},
dfnet={
'dfnet': None
}
)

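The new `dfnet` slot in `cached_model_instances` is what lets `load_dfnet_model()` reuse a single DeepFilterNet instance across calls; a small illustration of the expected caching behavior (assumes DeepFilterNet is installed):

from stable_whisper.audio.dfnet import load_dfnet_model

m1 = load_dfnet_model()                # first call runs init_df() and caches the model
m2 = load_dfnet_model()                # second call returns the cached instance
assert m1 is m2
fresh = load_dfnet_model(cache=False)  # bypasses the cache and initializes a new model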
54 changes: 38 additions & 16 deletions stable_whisper/whisper_word_level/cli.py
@@ -66,6 +66,12 @@ def update_options_with_args(arg_key: Union[str, list], options: Optional[dict]
return extra_options
options.update(extra_options)

def url_to_path(url: str):
if '://' in url:
from urllib.parse import urlparse
return urlparse(url).path.strip('/')
return url

OUTPUT_FORMATS_METHODS = {
"srt": "to_srt_vtt",
"ass": "to_ass",
@@ -323,13 +329,21 @@ def update_options_with_args(arg_key: Union[str, list], options: Optional[dict]
parser.add_argument('--faster_whisper', '-fw', action='store_true',
help='whether to use faster-whisper (https://github.com/guillaumekln/faster-whisper); '
'note: some features may not be available')
parser.add_argument('--huggingface_whisper', '-hw', action='store_true',
help='whether to run Whisper on Hugging Face Transformers for more speed than faster-whisper'
' and even more speed with Flash Attention enabled on supported GPUs'
' (https://huggingface.co/openai/whisper-large-v3); '
'note: some features may not be available')

args = parser.parse_args().__dict__
debug = args.pop('debug')
if not args['language'] and (args['align'] or args['locate']):
raise ValueError('language is required for --align / --locate')

is_faster_whisper = args.pop('faster_whisper')
is_hf_whisper = args.pop('huggingface_whisper')
assert not (is_faster_whisper and is_hf_whisper), f'--huggingface_whisper cannot be used with --faster_whisper'
is_original_whisper = not (is_faster_whisper or is_hf_whisper)
model_name: str = args.pop("model")
model_dir: str = args.pop("model_dir")
inputs: List[Union[str, torch.Tensor]] = args.pop("inputs")
@@ -357,13 +371,29 @@ def update_options_with_args(arg_key: Union[str, list], options: Optional[dict]
if args['reverse_text']:
args['reverse_text'] = (args.get('prepend_punctuations'), args.get('append_punctuations'))

if is_faster_whisper:
if is_original_whisper:
model_type_name = 'Whisper'
from .original_whisper import load_model as load_model_func
model_name_kwarg = dict(name=model_name)
else:
if is_faster_whisper:
model_type_name = 'Faster-Whisper'
from .faster_whisper import load_faster_whisper as load_model_func
model_name_kwarg = dict(model_size_or_path=model_name)
else:
model_type_name = 'Hugging Face Whisper'
from .hf_whisper import load_hf_whisper as load_model_func
model_name_kwarg = dict(model_name=model_name)

if args.get('transcribe_method') == 'transcribe_minimal':
warnings.warn('Faster-Whisper models already run on a version of transcribe_minimal. '
warnings.warn(f'{model_type_name} models already run on a version of transcribe_minimal. '
'--transcribe_method "transcribe_minimal" will be ignored.')
args['transcribe_method'] = 'transcribe'
if args.get('refine'):
raise NotImplementedError('--refine is not supported for Faster-Whisper models.')
raise NotImplementedError(f'--refine is not supported for {model_type_name} models.')
if strings_to_locate:
raise NotImplementedError(f'--locate is not supported for {model_type_name} models.')
if is_faster_whisper:
args['transcribe_method'] = 'transcribe_stable'

if regroup:
try:
@@ -437,7 +467,7 @@ def val_to_str(val) -> str:

def finalize_outputs(input_file: str, _output: str = None, _alignment: str = None) -> List[str]:
_curr_output_formats = curr_output_formats.copy()
basename, ext = splitext(_output or input_file)
basename, ext = splitext(_output or url_to_path(input_file))
ext = ext[1:]
if _output:
if ext.lower() in OUTPUT_FORMATS:
@@ -498,8 +528,7 @@ def finalize_outputs(input_file: str, _output: str = None, _alignment: str = Non
if show_curr_task:
model_from_str = '' if model_dir is None else f' from {model_dir}'
model_loading_str = (
f'{"Faster-Whisper" if is_faster_whisper else "Whisper"} '
f'{model_name} model {model_from_str}'
f'{model_type_name} {model_name} model {model_from_str}'
)
print(f'Loading {model_loading_str}\r', end='\n' if debug else '')
else:
nonlocal model
if model is None:
model_options = dict(
name=model_name,
model_size_or_path=model_name,
device=args.get('device'),
download_root=model_dir,
dq=dq,
)
if is_faster_whisper:
from .faster_whisper import load_faster_whisper as load_model_func
else:
from .original_whisper import load_model as load_model_func
model_options.update(model_name_kwarg)
model_options = isolate_useful_options(model_options, load_model_func)
update_options_with_args('model_option', model_options)
model = call_method_with_options(load_model_func, model_options)
@@ -549,15 +573,13 @@ def _load_model():
text = f.read()
args['text'] = text
transcribe_method = 'align'
if is_faster_whisper and transcribe_method == 'transcribe':
transcribe_method = 'transcribe_stable'
if strings_to_locate and (text := strings_to_locate[i]):
args['text'] = text
transcribe_method = 'locate'
skip_output = args['verbose'] = True
transcribe_method = getattr(model, transcribe_method)
transcribe_options = isolate_useful_options(args, transcribe_method)
if not text:
if not text and not is_hf_whisper:
decoding_options = (
isolate_useful_options(args, model.transcribe if is_faster_whisper else DecodingOptions)
)
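The new `url_to_path` helper above is what allows the CLI to derive a default output name when the input is a URL; a quick sketch of its behavior (the URLs are placeholders, and the function body is copied from the diff above):

from urllib.parse import urlparse

def url_to_path(url: str):
    # keep only the path component so it can be fed to splitext() for output naming
    if '://' in url:
        return urlparse(url).path.strip('/')
    return url

print(url_to_path('https://example.com/media/clip.mp3'))  # -> 'media/clip.mp3'
print(url_to_path('local/clip.mp3'))                      # -> 'local/clip.mp3'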
11 changes: 7 additions & 4 deletions stable_whisper/whisper_word_level/hf_whisper.py
@@ -3,7 +3,7 @@
import numpy as np
import torch

from ..audio import convert_demucs_kwargs
from ..audio import convert_demucs_kwargs, prep_audio
from ..non_whisper import transcribe_any


@@ -122,9 +122,12 @@ def transcribe(
demucs=options.pop('demucs', None), demucs_options=options.pop('demucs_options', None)
)

if not isinstance(audio, (str, bytes)):
if 'input_sr' not in options:
options['input_sr'] = self.sampling_rate
if isinstance(audio, (str, bytes)):
audio = prep_audio(audio, sr=self.sampling_rate).numpy()
options['input_sr'] = self.sampling_rate

if 'input_sr' not in options:
options['input_sr'] = self.sampling_rate

if denoiser or only_voice_freq:
if 'audio_type' not in options:
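With this change, string and bytes inputs (including URLs) are decoded through `prep_audio` before reaching the Hugging Face pipeline. A hedged end-to-end sketch (the URL is a placeholder; `load_hf_whisper` being exported at the package level is assumed from the CLI import in this commit, and the `transformers` package must be installed):

import stable_whisper

model = stable_whisper.load_hf_whisper('base')
result = model.transcribe('https://example.com/remote_audio.mp3')  # previously failed for URL inputs
result.to_srt_vtt('remote_audio.srt')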
