Skip to content

Commit

Permalink
-Fixed [demucs] handling for transcribe_minimal()
Browse files Browse the repository at this point in the history
-fixed [demucs] for `transcribe_minimal()` with [audio_type]='torch' and [model_sr]=whisper.audio.SAMPLE_RATE
-updated transcribe_minimal() to accept all options that both transcribe_any() and whisper.transcribe() accept
-fixed --debug not showing the first option
  • Loading branch information
jianfch committed Oct 4, 2023
1 parent 17df943 commit 857df9a
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 10 deletions.
5 changes: 5 additions & 0 deletions stable_whisper/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,8 @@ def str_to_valid_type(val: str):

def get_func_parameters(func):
    """Return a view of the parameter names accepted by ``func``.

    Uses :func:`inspect.signature`, so it works for plain functions,
    methods, and most callables that expose a signature.
    """
    signature = inspect.signature(func)
    return signature.parameters.keys()


def isolate_useful_options(options: dict, method, pop: bool = False) -> dict:
    """Return the subset of ``options`` that ``method`` accepts as parameters.

    Parameters
    ----------
    options:
        Candidate keyword arguments.
    method:
        Callable whose signature decides which keys are kept.
    pop:
        If True, the matched entries are also removed from ``options``;
        otherwise ``options`` is left unmodified.
    """
    getter = dict.pop if pop else dict.get
    useful = {}
    # iterate the signature (not `options`) so result order follows the parameters
    for name in get_func_parameters(method):
        if name in options:
            useful[name] = getter(options, name)
    return useful
30 changes: 20 additions & 10 deletions stable_whisper/whisper_word_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .timing import add_word_timestamps_stable
from .stabilization import get_vad_silence_func, wav2mask, mask2timing, timing2mask
from .non_whisper import transcribe_any
from .utils import warn_compatibility_issues
from .utils import warn_compatibility_issues, isolate_useful_options
from .alignment import align

if TYPE_CHECKING:
Expand Down Expand Up @@ -684,13 +684,14 @@ def transcribe_minimal(
k_size: int = 5,
demucs: bool = False,
demucs_output: str = None,
demucs_options: dict = None,
vad: bool = False,
vad_threshold: float = 0.35,
vad_onnx: bool = False,
min_word_dur: float = 0.1,
only_voice_freq: bool = False,
only_ffmpeg: bool = False,
**kwargs) \
**options) \
-> WhisperResult:
"""
Transcribe audio using unmodified Whisper.transcribe().
Expand All @@ -701,7 +702,13 @@ def transcribe_minimal(
word_timestamps=word_timestamps,
verbose=verbose
)
inference_kwargs.update(kwargs)
extra_options = isolate_useful_options(options, transcribe_any, True)
if demucs:
if 'audio_type' not in extra_options:
extra_options['audio_type'] = 'torch'
if 'model_sr' not in extra_options:
extra_options['model_sr'] = SAMPLE_RATE
inference_kwargs.update(options)
return transcribe_any(
inference_func=whisper.transcribe,
audio=audio,
Expand All @@ -714,13 +721,15 @@ def transcribe_minimal(
k_size=k_size,
demucs=demucs,
demucs_output=demucs_output,
demucs_options=demucs_options,
vad=vad,
vad_threshold=vad_threshold,
vad_onnx=vad_onnx,
min_word_dur=min_word_dur,
only_voice_freq=only_voice_freq,
only_ffmpeg=only_ffmpeg,
force_order=True
force_order=True,
**extra_options
)


Expand Down Expand Up @@ -1176,7 +1185,7 @@ def make_parent(filepath: str):
def is_json(file: str):
    """Return True when ``file`` ends with the ``.json`` extension."""
    ext = ".json"
    # equivalent to file.endswith(ext): a too-short name yields a shorter
    # slice that can never equal ext
    return file[-len(ext):] == ext

def call_method_with_options(method, options: dict, include_first: bool = False):
def call_method_with_options(method, options: dict, include_first: bool = True):
def val_to_str(val) -> str:
if isinstance(val, (np.ndarray, torch.Tensor)):
return f'{val.__class__}(shape:{list(val.shape)})'
Expand All @@ -1198,12 +1207,13 @@ def val_to_str(val) -> str:
for k, v in options.items()
if include_first or k != params[0]
)
print(f'{method.__qualname__}(\n{options_str}\n)')
if options_str:
options_str = f'\n{options_str}\n'
else:
print(options, params)
print(f'{method.__qualname__}({options_str})')
return method(**options)

def isolate_useful_options(options: dict, method) -> dict:
    """Return only the entries of ``options`` that match ``method``'s parameters.

    ``options`` itself is never modified; keys in the result follow the
    order of ``method``'s signature.
    """
    subset = {}
    for name in get_func_parameters(method):
        if name in options:
            subset[name] = options.get(name)
    return subset

def finalize_outputs(input_file: str, _output: str = None) -> List[str]:
_curr_output_formats = curr_output_formats.copy()
basename, ext = splitext(_output or input_file)
Expand Down Expand Up @@ -1290,7 +1300,7 @@ def cancel_overwrite():
dq=dq,
)
update_options_with_args('model_option', model_options)
model = call_method_with_options(load_model, model_options, include_first=True)
model = call_method_with_options(load_model, model_options)
if model_loading_str:
print(f'Loaded {model_loading_str} ')
args['regroup'] = False
Expand Down

0 comments on commit 857df9a

Please sign in to comment.