Skip to content

Commit

Permalink
Merge pull request #646 from santialferez/diarize-patch-1
Browse files Browse the repository at this point in the history
Update pyannote to v3.1.1 to fix a diarization problem (and diarize.py)
  • Loading branch information
m-bain authored Jan 3, 2024
2 parents 06e30b2 + 8ae6416 commit 59962a7
Show file tree
Hide file tree
Showing 15 changed files with 2,366 additions and 3 deletions.
227 changes: 227 additions & 0 deletions build/lib/whisperx/SubtitlesProcessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
import math
from conjunctions import get_conjunctions, get_comma
from typing import TextIO

def normal_round(n):
if n - math.floor(n) < 0.5:
return math.floor(n)
return math.ceil(n)


def format_timestamp(seconds: float, is_vtt: bool = False):

assert seconds >= 0, "non-negative timestamp expected"
milliseconds = round(seconds * 1000.0)

hours = milliseconds // 3_600_000
milliseconds -= hours * 3_600_000

minutes = milliseconds // 60_000
milliseconds -= minutes * 60_000

seconds = milliseconds // 1_000
milliseconds -= seconds * 1_000

separator = '.' if is_vtt else ','

hours_marker = f"{hours:02d}:"
return (
f"{hours_marker}{minutes:02d}:{seconds:02d}{separator}{milliseconds:03d}"
)



class SubtitlesProcessor:
def __init__(self, segments, lang, max_line_length = 45, min_char_length_splitter = 30, is_vtt = False):
self.comma = get_comma(lang)
self.conjunctions = set(get_conjunctions(lang))
self.segments = segments
self.lang = lang
self.max_line_length = max_line_length
self.min_char_length_splitter = min_char_length_splitter
self.is_vtt = is_vtt
complex_script_languages = ['th', 'lo', 'my', 'km', 'am', 'ko', 'ja', 'zh', 'ti', 'ta', 'te', 'kn', 'ml', 'hi', 'ne', 'mr', 'ar', 'fa', 'ur', 'ka']
if self.lang in complex_script_languages:
self.max_line_length = 30
self.min_char_length_splitter = 20

def estimate_timestamp_for_word(self, words, i, next_segment_start_time=None):
k = 0.25
has_prev_end = i > 0 and 'end' in words[i - 1]
has_next_start = i < len(words) - 1 and 'start' in words[i + 1]

if has_prev_end:
words[i]['start'] = words[i - 1]['end']
if has_next_start:
words[i]['end'] = words[i + 1]['start']
else:
if next_segment_start_time:
words[i]['end'] = next_segment_start_time if next_segment_start_time - words[i - 1]['end'] <= 1 else next_segment_start_time - 0.5
else:
words[i]['end'] = words[i]['start'] + len(words[i]['word']) * k

elif has_next_start:
words[i]['start'] = words[i + 1]['start'] - len(words[i]['word']) * k
words[i]['end'] = words[i + 1]['start']

else:
if next_segment_start_time:
words[i]['start'] = next_segment_start_time - 1
words[i]['end'] = next_segment_start_time - 0.5
else:
words[i]['start'] = 0
words[i]['end'] = 0



def process_segments(self, advanced_splitting=True):
subtitles = []
for i, segment in enumerate(self.segments):
next_segment_start_time = self.segments[i + 1]['start'] if i + 1 < len(self.segments) else None

if advanced_splitting:

split_points = self.determine_advanced_split_points(segment, next_segment_start_time)
subtitles.extend(self.generate_subtitles_from_split_points(segment, split_points, next_segment_start_time))
else:
words = segment['words']
for i, word in enumerate(words):
if 'start' not in word or 'end' not in word:
self.estimate_timestamp_for_word(words, i, next_segment_start_time)

subtitles.append({
'start': segment['start'],
'end': segment['end'],
'text': segment['text']
})

return subtitles

def determine_advanced_split_points(self, segment, next_segment_start_time=None):
split_points = []
last_split_point = 0
char_count = 0

words = segment.get('words', segment['text'].split())
add_space = 0 if self.lang in ['zh', 'ja'] else 1

total_char_count = sum(len(word['word']) if isinstance(word, dict) else len(word) + add_space for word in words)
char_count_after = total_char_count

for i, word in enumerate(words):
word_text = word['word'] if isinstance(word, dict) else word
word_length = len(word_text) + add_space
char_count += word_length
char_count_after -= word_length

char_count_before = char_count - word_length

if isinstance(word, dict) and ('start' not in word or 'end' not in word):
self.estimate_timestamp_for_word(words, i, next_segment_start_time)

if char_count >= self.max_line_length:
midpoint = normal_round((last_split_point + i) / 2)
if char_count_before >= self.min_char_length_splitter:
split_points.append(midpoint)
last_split_point = midpoint + 1
char_count = sum(len(words[j]['word']) if isinstance(words[j], dict) else len(words[j]) + add_space for j in range(last_split_point, i + 1))

elif word_text.endswith(self.comma) and char_count_before >= self.min_char_length_splitter and char_count_after >= self.min_char_length_splitter:
split_points.append(i)
last_split_point = i + 1
char_count = 0

elif word_text.lower() in self.conjunctions and char_count_before >= self.min_char_length_splitter and char_count_after >= self.min_char_length_splitter:
split_points.append(i - 1)
last_split_point = i
char_count = word_length

return split_points


def generate_subtitles_from_split_points(self, segment, split_points, next_start_time=None):
subtitles = []

words = segment.get('words', segment['text'].split())
total_word_count = len(words)
total_time = segment['end'] - segment['start']
elapsed_time = segment['start']
prefix = ' ' if self.lang not in ['zh', 'ja'] else ''
start_idx = 0
for split_point in split_points:

fragment_words = words[start_idx:split_point + 1]
current_word_count = len(fragment_words)


if isinstance(fragment_words[0], dict):
start_time = fragment_words[0]['start']
end_time = fragment_words[-1]['end']
next_start_time_for_word = words[split_point + 1]['start'] if split_point + 1 < len(words) else None
if next_start_time_for_word and (next_start_time_for_word - end_time) <= 0.8:
end_time = next_start_time_for_word
else:
fragment = prefix.join(fragment_words).strip()
current_duration = (current_word_count / total_word_count) * total_time
start_time = elapsed_time
end_time = elapsed_time + current_duration
elapsed_time += current_duration


subtitles.append({
'start': start_time,
'end': end_time,
'text': fragment if not isinstance(fragment_words[0], dict) else prefix.join(word['word'] for word in fragment_words)
})

start_idx = split_point + 1

# Handle the last fragment
if start_idx < len(words):
fragment_words = words[start_idx:]
current_word_count = len(fragment_words)

if isinstance(fragment_words[0], dict):
start_time = fragment_words[0]['start']
end_time = fragment_words[-1]['end']
else:
fragment = prefix.join(fragment_words).strip()
current_duration = (current_word_count / total_word_count) * total_time
start_time = elapsed_time
end_time = elapsed_time + current_duration

if next_start_time and (next_start_time - end_time) <= 0.8:
end_time = next_start_time

subtitles.append({
'start': start_time,
'end': end_time if end_time is not None else segment['end'],
'text': fragment if not isinstance(fragment_words[0], dict) else prefix.join(word['word'] for word in fragment_words)
})

return subtitles



def save(self, filename="subtitles.srt", advanced_splitting=True):

subtitles = self.process_segments(advanced_splitting)

def write_subtitle(file, idx, start_time, end_time, text):

file.write(f"{idx}\n")
file.write(f"{start_time} --> {end_time}\n")
file.write(text + "\n\n")

with open(filename, 'w', encoding='utf-8') as file:
if self.is_vtt:
file.write("WEBVTT\n\n")

if advanced_splitting:
for idx, subtitle in enumerate(subtitles, 1):
start_time = format_timestamp(subtitle['start'], self.is_vtt)
end_time = format_timestamp(subtitle['end'], self.is_vtt)
text = subtitle['text'].strip()
write_subtitle(file, idx, start_time, end_time, text)

return len(subtitles)
4 changes: 4 additions & 0 deletions build/lib/whisperx/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .transcribe import load_model
from .alignment import load_align_model, align
from .audio import load_audio
from .diarize import assign_word_speakers, DiarizationPipeline
4 changes: 4 additions & 0 deletions build/lib/whisperx/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .transcribe import cli


cli()
Loading

0 comments on commit 59962a7

Please sign in to comment.