Skip to content

Commit

Permalink
updated silence suppression
Browse files Browse the repository at this point in the history
-changed silence suppression to treat multiple nonspeech sections within a word's duration as individual sections instead of treating them as one big section; thus less prone to prematurely cutting off a word due to gaps between its syllables
-updated `use_word_position=True` to also take into account the index of each word instead of only whether it ends with default punctuations
-changed `align()` to prioritize new timestamps if its speech percentage is within 10% rounding error of previous timestamps
  • Loading branch information
jianfch committed May 3, 2024
1 parent 0546d76 commit 5ca7ca5
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 20 deletions.
2 changes: 1 addition & 1 deletion stable_whisper/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2.16.0"
__version__ = "2.17.0"
4 changes: 2 additions & 2 deletions stable_whisper/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,8 +440,8 @@ def speech_percentage(_word: dict, _mask: torch.Tensor, _offset: float):
return 1 - _mask[s:e].float().mean().nan_to_num()

def is_new_better(w0, m0, o0, w1, m1, o1):
speech0 = speech_percentage(w0, m0, o0)
speech1 = speech_percentage(w1, m1, o1)
speech0 = speech_percentage(w0, m0, o0).round(decimals=1)
speech1 = speech_percentage(w1, m1, o1).round(decimals=1)
return speech0 >= speech1 or w0['probability'] >= w1['probability']

with tqdm(total=initial_duration, unit='sec', disable=verbose is not False, desc='Align') as tqdm_pbar:
Expand Down
2 changes: 1 addition & 1 deletion stable_whisper/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,7 @@ def suppress_silence(self,
words = self.words if word_level or len(self.words) == 1 else [self.words[0], self.words[-1]]
for i, w in enumerate(words, 1):
if use_word_position:
keep_end = w.word[-1] not in ending_punctuations
keep_end = not (w.word[-1] in ending_punctuations or i == len(words))
else:
keep_end = None
w.suppress_silence(silent_starts, silent_ends, min_word_dur, nonspeech_error, keep_end)
Expand Down
53 changes: 37 additions & 16 deletions stable_whisper/stabilization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,8 @@ def predict_with_samples(
offset: Optional[float] = None
) -> dict:
if self.get_mask:
if extra_len := audio.shape[-1] % N_SAMPLES_PER_TOKEN:
audio = torch.nn.functional.pad(audio, (0, N_SAMPLES_PER_TOKEN - extra_len))
mask = torch.all(audio.reshape(-1, N_SAMPLES_PER_TOKEN), dim=-1)
min_unit_per_word = self.min_frames_per_word
else:
Expand Down Expand Up @@ -308,7 +310,7 @@ def suppress_silence(
if isinstance(silent_ends, list):
silent_ends = np.array(silent_ends)

start_overlaps = np.all(
start_overlaps = (keep_end is None or keep_end) and np.all(
(silent_starts <= result_obj.start, result_obj.start < silent_ends, silent_ends <= result_obj.end),
axis=0
).nonzero()[0].tolist()
Expand All @@ -318,7 +320,7 @@ def suppress_silence(
if (result_obj.end - result_obj.start) <= min_word_dur:
return

end_overlaps = np.all(
end_overlaps = not keep_end and np.all(
(result_obj.start <= silent_starts, silent_starts < result_obj.end, result_obj.end <= silent_ends),
axis=0
).nonzero()[0].tolist()
Expand All @@ -335,25 +337,44 @@ def suppress_silence(
).nonzero()[0].tolist()
if len(matches) != 1:
return
silence_start = silent_starts[matches[0]]
silence_end = silent_ends[matches[0]]
start_extra = silence_start - result_obj.start
end_extra = result_obj.end - silence_end
silent_duration = silence_end - silence_start
start_within_error = (start_extra / silent_duration) <= nonspeech_error
end_within_error = (end_extra / silent_duration) <= nonspeech_error
if keep_end is None:
keep_end = start_extra <= end_extra
within_error = start_within_error if keep_end else end_within_error
else:
within_error = start_within_error or end_within_error

if within_error:
if keep_end:
def silence_errors(silence_start, silence_end):
start_extra = silence_start - result_obj.start
end_extra = result_obj.end - silence_end
silent_duration = silence_end - silence_start
start_error = start_extra / silent_duration
end_error = end_extra / silent_duration
return start_error, end_error

def _adjust(silence_start, silence_end, errors=None):
if not errors:
errors = silence_errors(silence_start, silence_end)
_keep_end = keep_end
start_within_error = errors[0] <= nonspeech_error
end_within_error = errors[1] <= nonspeech_error
if _keep_end is None:
_keep_end = errors[0] <= errors[1]
if not (start_within_error or end_within_error):
return
if _keep_end:
result_obj.start = min(silence_end, round(result_obj.end - min_word_dur, 3))
else:
result_obj.end = max(silence_start, round(result_obj.start + min_word_dur, 3))

max_i = len(matches) - 1
for i in range(len(matches)):
error = None
if i == max_i:
idx = 0
elif keep_end is None:
error0 = silence_errors(silent_starts[matches[0]], silent_ends[matches[0]])
error1 = silence_errors(silent_starts[matches[-1]], silent_ends[matches[-1]])
idx, error = (0, error0) if min(error0) <= min(error1) else (-1, error1)
else:
idx = 0 if keep_end else -1
idx = matches.pop(idx)
_adjust(silent_starts[idx], silent_ends[idx], error)


def visualize_suppression(
audio: Union[torch.Tensor, np.ndarray, str, bytes],
Expand Down

0 comments on commit 5ca7ca5

Please sign in to comment.