From 5ca7ca5c9e988b74d94d0c1699cd6da11e3273af Mon Sep 17 00:00:00 2001
From: Jian
Date: Thu, 2 May 2024 23:30:01 -0400
Subject: [PATCH] updated silence suppression

-changed silence suppression to treat multiple nonspeech sections within a word's duration as individual sections instead of one large section, making it less prone to prematurely cutting off a word due to gaps between its syllables
-updated `use_word_position=True` to also take into account the index of each word instead of only whether it ends with one of the default punctuations
-changed `align()` to prioritize new timestamps when their speech percentage is within 10% rounding error of the previous timestamps
---
 stable_whisper/_version.py               |  2 +-
 stable_whisper/alignment.py              |  4 +-
 stable_whisper/result.py                 |  2 +-
 stable_whisper/stabilization/__init__.py | 53 +++++++++++++++++-------
 4 files changed, 41 insertions(+), 20 deletions(-)

diff --git a/stable_whisper/_version.py b/stable_whisper/_version.py
index 8f4a351..a6b62ff 100644
--- a/stable_whisper/_version.py
+++ b/stable_whisper/_version.py
@@ -1 +1 @@
-__version__ = "2.16.0"
+__version__ = "2.17.0"
diff --git a/stable_whisper/alignment.py b/stable_whisper/alignment.py
index 322d11b..483bbba 100644
--- a/stable_whisper/alignment.py
+++ b/stable_whisper/alignment.py
@@ -440,8 +440,8 @@ def speech_percentage(_word: dict, _mask: torch.Tensor, _offset: float):
         return 1 - _mask[s:e].float().mean().nan_to_num()

     def is_new_better(w0, m0, o0, w1, m1, o1):
-        speech0 = speech_percentage(w0, m0, o0)
-        speech1 = speech_percentage(w1, m1, o1)
+        speech0 = speech_percentage(w0, m0, o0).round(decimals=1)
+        speech1 = speech_percentage(w1, m1, o1).round(decimals=1)
         return speech0 >= speech1 or w0['probability'] >= w1['probability']

     with tqdm(total=initial_duration, unit='sec', disable=verbose is not False, desc='Align') as tqdm_pbar:
diff --git a/stable_whisper/result.py b/stable_whisper/result.py
index 8f6e99f..872bf80 100644
--- a/stable_whisper/result.py
+++ b/stable_whisper/result.py
@@ -648,7 +648,7 @@ def suppress_silence(self,
             words = self.words if word_level or len(self.words) == 1 else [self.words[0], self.words[-1]]
             for i, w in enumerate(words, 1):
                 if use_word_position:
-                    keep_end = w.word[-1] not in ending_punctuations
+                    keep_end = not (w.word[-1] in ending_punctuations or i == len(words))
                 else:
                     keep_end = None
                 w.suppress_silence(silent_starts, silent_ends, min_word_dur, nonspeech_error, keep_end)
diff --git a/stable_whisper/stabilization/__init__.py b/stable_whisper/stabilization/__init__.py
index 68c5db1..b3e2d1a 100644
--- a/stable_whisper/stabilization/__init__.py
+++ b/stable_whisper/stabilization/__init__.py
@@ -257,6 +257,8 @@ def predict_with_samples(
             offset: Optional[float] = None
     ) -> dict:
         if self.get_mask:
+            if extra_len := audio.shape[-1] % N_SAMPLES_PER_TOKEN:
+                audio = torch.nn.functional.pad(audio, (0, N_SAMPLES_PER_TOKEN - extra_len))
             mask = torch.all(audio.reshape(-1, N_SAMPLES_PER_TOKEN), dim=-1)
             min_unit_per_word = self.min_frames_per_word
         else:
@@ -308,7 +310,7 @@ def suppress_silence(
     if isinstance(silent_ends, list):
         silent_ends = np.array(silent_ends)

-    start_overlaps = np.all(
+    start_overlaps = (keep_end is None or keep_end) and np.all(
         (silent_starts <= result_obj.start, result_obj.start < silent_ends, silent_ends <= result_obj.end),
         axis=0
     ).nonzero()[0].tolist()
@@ -318,7 +320,7 @@ def suppress_silence(
     if (result_obj.end - result_obj.start) <= min_word_dur:
         return

-    end_overlaps = np.all(
+    end_overlaps = not keep_end and np.all(
         (result_obj.start <= silent_starts, silent_starts < result_obj.end, result_obj.end <= silent_ends),
         axis=0
     ).nonzero()[0].tolist()
@@ -335,25 +337,44 @@ def suppress_silence(
     ).nonzero()[0].tolist()
     if len(matches) != 1:
         return
-    silence_start = silent_starts[matches[0]]
-    silence_end = silent_ends[matches[0]]
-    start_extra = silence_start - result_obj.start
-    end_extra = result_obj.end - silence_end
-    silent_duration = silence_end - silence_start
-    start_within_error = (start_extra / silent_duration) <= nonspeech_error
-    end_within_error = (end_extra / silent_duration) <= nonspeech_error
-    if keep_end is None:
-        keep_end = start_extra <= end_extra
-        within_error = start_within_error if keep_end else end_within_error
-    else:
-        within_error = start_within_error or end_within_error
-    if within_error:
-        if keep_end:
+    def silence_errors(silence_start, silence_end):
+        start_extra = silence_start - result_obj.start
+        end_extra = result_obj.end - silence_end
+        silent_duration = silence_end - silence_start
+        start_error = start_extra / silent_duration
+        end_error = end_extra / silent_duration
+        return start_error, end_error
+
+    def _adjust(silence_start, silence_end, errors=None):
+        if not errors:
+            errors = silence_errors(silence_start, silence_end)
+        _keep_end = keep_end
+        start_within_error = errors[0] <= nonspeech_error
+        end_within_error = errors[1] <= nonspeech_error
+        if _keep_end is None:
+            _keep_end = errors[0] <= errors[1]
+        if not (start_within_error or end_within_error):
+            return
+        if _keep_end:
             result_obj.start = min(silence_end, round(result_obj.end - min_word_dur, 3))
         else:
             result_obj.end = max(silence_start, round(result_obj.start + min_word_dur, 3))
+
+    max_i = len(matches) - 1
+    for i in range(len(matches)):
+        error = None
+        if i == max_i:
+            idx = 0
+        elif keep_end is None:
+            error0 = silence_errors(silent_starts[matches[0]], silent_ends[matches[0]])
+            error1 = silence_errors(silent_starts[matches[-1]], silent_ends[matches[-1]])
+            idx, error = (0, error0) if min(error0) <= min(error1) else (-1, error1)
+        else:
+            idx = 0 if keep_end else -1
+        idx = matches.pop(idx)
+        _adjust(silent_starts[idx], silent_ends[idx], error)


 def visualize_suppression(
         audio: Union[torch.Tensor, np.ndarray, str, bytes],
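
The `.round(decimals=1)` added in `is_new_better()` buckets the speech percentages into 0.1-wide steps before comparing them, so a marginal deficit in speech coverage, within that ~10% rounding step, no longer rejects the new timestamps on its own; the word probability then decides. A minimal standalone sketch of that comparison (plain Python rather than the library's tensor code; names and sample values are illustrative):

    def is_new_better(speech_new, prob_new, speech_old, prob_old):
        # Round speech coverage to one decimal (0.1-wide buckets) so a gap
        # smaller than the rounding step does not disqualify the new timing.
        return round(speech_new, 1) >= round(speech_old, 1) or prob_new >= prob_old

    # 0.78 and 0.82 both round to 0.8, so the new timestamps are kept here.
    assert is_new_better(0.78, 0.9, 0.82, 0.7)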
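On the user-facing side, these changes flow through the existing silence-suppression options rather than any new API. A hedged usage sketch (assumes stable-ts is installed and an `audio.mp3` file exists; it relies on the documented `suppress_silence`, `use_word_position`, `min_word_dur`, and `nonspeech_error` parameters, with values shown only as examples):

    import stable_whisper

    model = stable_whisper.load_model('base')
    # use_word_position=True now also considers each word's position in its
    # segment, not just whether the word ends with punctuation.
    result = model.transcribe(
        'audio.mp3',
        suppress_silence=True,
        use_word_position=True,
        min_word_dur=0.1,
        nonspeech_error=0.1,
    )
    result.to_srt_vtt('audio.srt')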