Improve audio splitting in dataset generation #419

Draft · wants to merge 4 commits into base: alltalkbeta

Changes from all commits
259 changes: 110 additions & 149 deletions finetune.py
@@ -1164,11 +1164,11 @@ def format_audio_list(
debug_print("Processing with VAD", "AUDIO")
# Get VAD segments with resampling
vad_segments = process_audio_with_vad(
wav, sr, vad_model, get_speech_timestamps)
wav, sr, vad_model, get_speech_timestamps, max_duration=max_duration)

# Group short segments that are close together
merged_segments = merge_short_segments(
vad_segments, min_duration, max_gap=0.3)
vad_segments, min_duration, sr=sr, max_gap=0.3)
debug_print(
f"Merged {len(vad_segments)-len(merged_segments)} short segments",
"SEGMENTS")
@@ -1515,7 +1515,7 @@ def _create_bpe_tokenizer(bpe_whisper_words, bpe_out_path, bpe_base_path):
raise


def merge_short_segments(segments, min_duration, max_gap=0.5):
def merge_short_segments(segments, min_duration, sr, max_gap=0.5):
"""
More aggressive merge strategy for short segments
- Increases max_gap to 0.5s (from 0.3s)
@@ -1527,7 +1527,7 @@ def merge_short_segments(segments, min_duration, max_gap=0.5):

merged = []
current_group = []
target_duration = (min_duration + 10.0) / 2 # Target middle of range
target_duration = sr * (min_duration + 10.0) / 2 # Target middle of range

for i, segment in enumerate(segments):
current_duration = sum(s["end"] - s["start"]
@@ -1566,16 +1566,15 @@ def merge_short_segments(segments, min_duration, max_gap=0.5):
merged.append(merged_segment)

debug_print(
f"Merged {len(segments) - len(merged)} segments into {len(merged)} segments with mid-range preference",
f"Merged {len(segments) - len(merged)} segments with mid-range preference, for a new total of {len(merged)}",
"SEGMENTS"
)
return merged
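
A quick worked example of the new sample-based target, assuming a 22050 Hz file and the min_duration of 6.0 s used elsewhere in this diff:

sr, min_duration = 22050, 6.0                 # assumed sample rate; min_duration as passed in
target_duration = sr * (min_duration + 10.0) / 2
# 22050 * 16.0 / 2 = 176400 samples, i.e. 8.0 s: the midpoint of the 6-10 s range,
# now in the same unit (samples) as the VAD segment boundaries it is compared against
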


def extend_segment(wav, start, end, sr, min_duration, context_window=1.0):
def extend_segment(wav, start, end, sr, min_duration):
"""
Improved segment extension with better context handling
- Adds context_window parameter for smoother extensions
- More balanced extension on both sides
- Checks audio content when extending
"""
@@ -1585,22 +1584,24 @@ def extend_segment(wav, start, end, sr, min_duration, context_window=1.0):

samples_needed = int((min_duration - current_duration) * sr)

# Check if we have enough samples in the file
if (samples_needed + current_duration * sr) > wav.size(-1):
# Pad the file symmetrically with zeroes
padding_amount = (samples_needed + current_duration * sr - wav.size(-1))
padding = torch.zeros(int(padding_amount // 2 + 1), dtype=wav.dtype)
return torch.cat([padding, wav, padding], dim=-1)

# Try to extend equally on both sides
extend_left = samples_needed // 2
extend_right = samples_needed - extend_left
new_start = max(0, start - extend_left)

# Add some context window
context_samples = int(context_window * sr)
new_start = max(0, start - extend_left - context_samples)
new_end = min(wav.size(-1), end + extend_right + context_samples)
# If there weren't enough samples on the left, extend more on the right side
extend_right = samples_needed - (start - new_start)
new_end = min(wav.size(-1), end + extend_right)

# Check if we got enough duration
if (new_end - new_start) / sr < min_duration:
# If still too short, try to compensate from the other side
if new_start == 0:
new_end = min(wav.size(-1), end + samples_needed + context_samples)
elif new_end == wav.size(-1):
new_start = max(0, start - samples_needed - context_samples)
# If there weren't enough samples on the right, extend more on the left side
extend_left = samples_needed - (new_end - end)
new_start = max(0, start - extend_left)

debug_print(
f"Extended segment from {current_duration:.2f}s to {(new_end - new_start) / sr:.2f}s",
@@ -1699,151 +1700,110 @@ def create_dataset_splits(df, eval_percentage, random_seed=42):


def save_audio_segment(
sas_audio,
sas_sr,
sas_start_time,
sas_end_time,
sas_sentence,
sas_audio_file_name_without_ext,
sas_segment_idx,
sas_speaker_name,
sas_audio_folder,
sas_metadata,
sas_max_duration,
_sas_buffer,
sas_too_long_files,
sas_target_language,
audio,
sr,
start_time,
end_time,
transcription,
audio_file_name_without_ext,
segment_idx,
speaker_name,
audio_folder,
metadata,
target_language,
):
"""Helper function to save audio segments and update metadata"""
sas_sentence = sas_sentence.strip()
sas_sentence = multilingual_cleaners(sas_sentence, sas_target_language)
sas_audio_file_name = f"{sas_audio_file_name_without_ext}_{str(sas_segment_idx).zfill(8)}.wav"
transcription = transcription.strip()
sentence = multilingual_cleaners(transcription, target_language)
audio_file_name = f"{audio_file_name_without_ext}_{str(segment_idx).zfill(8)}.wav"

sas_absolute_path = os.path.join(sas_audio_folder, sas_audio_file_name)
os.makedirs(os.path.dirname(sas_absolute_path), exist_ok=True)
absolute_path = os.path.join(audio_folder, audio_file_name)
os.makedirs(os.path.dirname(absolute_path), exist_ok=True)

# Extract audio segment
sas_audio_start = int(sas_sr * sas_start_time)
sas_audio_end = int(sas_sr * sas_end_time)
sas_audio_segment = sas_audio[sas_audio_start:sas_audio_end].unsqueeze(0)

# Handle long audio segments
if sas_audio_segment.size(-1) > sas_max_duration * sas_sr:
sas_too_long_files.append(
(sas_audio_file_name, sas_audio_segment.size(-1) / sas_sr))

while sas_audio_segment.size(-1) > sas_max_duration * sas_sr:
sas_split_audio = sas_audio_segment[:, : int(
sas_max_duration * sas_sr)]
sas_audio_segment = sas_audio_segment[:, int(
sas_max_duration * sas_sr):]
sas_split_file_name = f"{sas_audio_file_name_without_ext}_{str(sas_segment_idx).zfill(8)}.wav"
sas_split_relative_path = os.path.join(sas_split_file_name)
sas_split_absolute_path = os.path.normpath(
os.path.join(sas_audio_folder, sas_split_relative_path))

os.makedirs(
os.path.dirname(sas_split_absolute_path),
exist_ok=True)
torchaudio.save(str(sas_split_absolute_path), sas_split_audio, sas_sr)

sas_metadata["audio_file"].append(
f"wavs/{sas_split_relative_path}")
sas_metadata["text"].append(sas_sentence)
sas_metadata["speaker_name"].append(sas_speaker_name)
sas_segment_idx += 1
audio_start = int(sr * start_time)
audio_end = int(sr * end_time)
audio_segment = audio[audio_start:audio_end].unsqueeze(0)

# Only save if segment is at least 1 second
if sas_audio_segment.size(-1) >= sas_sr:
torchaudio.save(str(sas_absolute_path), sas_audio_segment, sas_sr)
sas_metadata["audio_file"].append(f"wavs/{sas_audio_file_name}")
sas_metadata["text"].append(sas_sentence)
sas_metadata["speaker_name"].append(sas_speaker_name)
if audio_segment.size(-1) >= sr:
torchaudio.save(str(absolute_path), audio_segment, sr)
metadata["audio_file"].append(f"wavs/{audio_file_name}")
metadata["text"].append(sentence)
metadata["speaker_name"].append(speaker_name)


def process_transcription_result(
ptr_result,
ptr_audio,
ptr_sr,
ptr_segment_idx,
ptr_audio_file_name_without_ext,
ptr_metadata,
ptr_whisper_words,
ptr_max_duration,
ptr_buffer,
ptr_speaker_name,
ptr_audio_folder,
ptr_too_long_files,
ptr_create_bpe_tokenizer,
ptr_target_language,
result,
audio,
sr,
segment_idx,
audio_file_name_without_ext,
metadata,
whisper_words,
buffer_time,
speaker_name,
audio_folder,
create_bpe_tokenizer,
target_language,
):
"""Helper function to process transcription results and save audio segments"""
ptr_i = ptr_segment_idx + 1
ptr_sentence = ""
ptr_sentence_start = None
ptr_first_word = True
ptr_current_words = []

for ptr_segment in ptr_result["segments"]:
if "words" not in ptr_segment:
i = segment_idx + 1
sentence = ""
first_word = True
segment_content = ""

for segment in result["segments"]:

if "words" not in segment:
continue

for ptr_word_info in ptr_segment["words"]:
ptr_word = ptr_word_info.get("word", "").strip()
if not ptr_word:
segment_content = ""
segment_start = segment["words"][0].get("start", 0) - buffer_time
segment_start = max(0, segment_start)
segment_end = segment_start

for word_info in segment["words"]:
word = word_info.get("word", "").strip()
if not word:
continue

ptr_start_time = ptr_word_info.get("start", 0)
ptr_end_time = ptr_word_info.get("end", 0)
end_time = word_info.get("end", 0)

if ptr_create_bpe_tokenizer:
ptr_whisper_words.append(ptr_word)
if create_bpe_tokenizer:
whisper_words.append(word)

if ptr_first_word:
ptr_sentence_start = ptr_start_time
if len(ptr_current_words) == 0:
ptr_sentence_start = max(
ptr_sentence_start - ptr_buffer, 0)
else:
ptr_previous_end = ptr_current_words[-1].get(
"end", 0) if ptr_current_words else 0
ptr_sentence_start = max(
ptr_sentence_start - ptr_buffer,
(ptr_previous_end + ptr_start_time) / 2)
ptr_sentence = ptr_word
ptr_first_word = False
if first_word:
sentence = word
first_word = False
else:
ptr_sentence += " " + ptr_word

ptr_current_words.append(
{"word": ptr_word, "start": ptr_start_time, "end": ptr_end_time})

# Handle sentence splitting and audio saving
if ptr_word[-1] in ["!", ".",
"?"] or (ptr_end_time - ptr_sentence_start) > ptr_max_duration:
save_audio_segment(
ptr_audio,
ptr_sr,
ptr_sentence_start,
ptr_end_time,
ptr_sentence,
ptr_audio_file_name_without_ext,
ptr_i,
ptr_speaker_name,
ptr_audio_folder,
ptr_metadata,
ptr_max_duration,
ptr_buffer,
ptr_too_long_files,
ptr_target_language,
)
ptr_i += 1
ptr_first_word = True
ptr_current_words = []
ptr_sentence = ""
sentence += " " + word

if word[-1] in ["!", ".", "?"]:
segment_content += sentence + " "
segment_end = end_time + buffer_time
first_word = True
sentence = ""

segment_content = segment_content.strip()
if segment_content:
save_audio_segment(
audio,
sr,
segment_start,
segment_end,
segment_content,
audio_file_name_without_ext,
i,
speaker_name,
audio_folder,
metadata,
target_language,
)
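
To make the new grouping rule concrete (words collect into a sentence until one ends in "!", "." or "?"; completed sentences are concatenated into segment_content; the whole Whisper segment is then written out once, padded by buffer_time), here is a small standalone trace with invented word timings:

buffer_time = 0.3                          # assumed padding, in seconds
words = [                                  # invented word-level timestamps, Whisper-style
    {"word": "Hello", "start": 0.4, "end": 0.7},
    {"word": "there.", "start": 0.8, "end": 1.2},
    {"word": "How", "start": 1.5, "end": 1.7},
    {"word": "are", "start": 1.8, "end": 1.9},
    {"word": "you?", "start": 2.0, "end": 2.4},
]

segment_start = max(0, words[0]["start"] - buffer_time)           # 0.1
segment_content, sentence, segment_end = "", "", segment_start
for w in words:
    sentence = w["word"] if not sentence else sentence + " " + w["word"]
    if w["word"][-1] in ["!", ".", "?"]:
        segment_content += sentence + " "
        segment_end = w["end"] + buffer_time                       # ends at 2.7 s
        sentence = ""
# segment_content == "Hello there. How are you? ", saved as one clip from 0.1 s to 2.7 s
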



def process_audio_with_vad(wav, sr, vad_model, get_speech_timestamps):
def process_audio_with_vad(wav, sr, vad_model, get_speech_timestamps, max_duration=float("inf")):
"""
Enhanced VAD processing with better end-of-speech detection
"""
Expand All @@ -1859,23 +1819,24 @@ def process_audio_with_vad(wav, sr, vad_model, get_speech_timestamps):
vad_model,
sampling_rate=16000,
threshold=0.2, # Lower threshold to be more sensitive to speech
neg_threshold=0.001, # Negative threshold needs to be set explicitly due to a bug in Silero VAD
min_speech_duration_ms=200, # Shorter to catch brief utterances
max_speech_duration_s=float("inf"),
max_speech_duration_s= max_duration,
min_silence_duration_ms=300, # Shorter silence duration
window_size_samples=1024, # Smaller window for more precise detection
speech_pad_ms=300, # Add padding to end of speech segments
#window_size_samples=1024, # Smaller window for more precise detection # DEPRECATED: does nothing
speech_pad_ms=100, # Add padding to end of speech segments
)

# Scale timestamps back to original sample rate
scale_factor = sr / 16000
for segment in vad_segments:
segment["start"] = int(segment["start"] * scale_factor)
# Add extra padding at the end
segment["end"] = int(segment["end"] * scale_factor) + \
int(0.2 * sr) # Add 200ms padding
segment["end"] = int(segment["end"] * scale_factor) #+ \
# int(0.2 * sr) # Add 200ms padding

merged_segments = merge_short_segments(
vad_segments, min_duration=6.0, max_gap=0.5)
vad_segments, min_duration=6.0, sr=sr, max_gap=0.5)

debug_print(
f"VAD processing: {len(vad_segments)} original segments, {len(merged_segments)} after merging",
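
As a closing note on the resampling math: get_speech_timestamps runs on a 16 kHz copy, so its sample indices are rescaled to the original rate before use. A short worked example, assuming the original file is 44.1 kHz:

sr = 44100                                        # assumed original sample rate
scale_factor = sr / 16000                         # 2.75625
segment = {"start": 16000, "end": 48000}          # 1.0 s .. 3.0 s at 16 kHz
segment["start"] = int(segment["start"] * scale_factor)   # 44100, still 1.0 s
segment["end"] = int(segment["end"] * scale_factor)       # 132300, still 3.0 s
# End-of-speech padding is now left to speech_pad_ms (100 ms in the call above) rather
# than adding a further fixed 200 ms here, so segment boundaries are not padded twice.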