diff --git a/autosub/main.py b/autosub/main.py index 3448661..a3ba1b6 100644 --- a/autosub/main.py +++ b/autosub/main.py @@ -150,12 +150,12 @@ def main(): base_directory = os.getcwd() output_directory = os.path.join(base_directory, "output") audio_directory = os.path.join(base_directory, "audio") - video_file_name = input_file.split(os.sep)[-1].split(".")[0] - audio_file_name = os.path.join(audio_directory, video_file_name + ".wav") + video_prefix = os.path.splitext(os.path.basename(input_file))[0] + audio_file_name = os.path.join(audio_directory, video_prefix + ".wav") output_file_handle_dict = {} for format in args.format: - output_filename = os.path.join(output_directory, video_file_name + "." + format) + output_filename = os.path.join(output_directory, video_prefix + "." + format) print("Creating file: " + output_filename) output_file_handle_dict[format] = open(output_filename, "w") # For VTT format, write header @@ -165,14 +165,15 @@ def main(): # Clean audio/ directory for filename in os.listdir(audio_directory): - file_path = os.path.join(audio_directory, filename) - try: - if os.path.isfile(file_path) or os.path.islink(file_path): - os.unlink(file_path) - elif os.path.isdir(file_path): - shutil.rmtree(file_path) - except Exception as e: - print('Failed to delete %s. Reason: %s' % (file_path, e)) + if filename.lower().endswith(".wav") and filename.startswith(video_prefix): + file_path = os.path.join(audio_directory, filename) + try: + if os.path.isfile(file_path) or os.path.islink(file_path): + os.unlink(file_path) + elif os.path.isdir(file_path): + shutil.rmtree(file_path) + except Exception as e: + print('Failed to delete %s. Reason: %s' % (file_path, e)) # Extract audio from input video file extract_audio(input_file, audio_file_name) @@ -182,11 +183,10 @@ def main(): print("\nRunning inference:") - for file in tqdm(sort_alphanumeric(os.listdir(audio_directory))): - audio_segment_path = os.path.join(audio_directory, file) - - # Dont run inference on the original audio file - if audio_segment_path.split(os.sep)[-1] != audio_file_name.split(os.sep)[-1]: + for filename in tqdm(sort_alphanumeric(os.listdir(audio_directory))): + # Only run inference on relevant files, and don't run inference on the original audio file + if filename.startswith(video_prefix) and (filename != os.path.basename(audio_file_name)): + audio_segment_path = os.path.join(audio_directory, filename) ds_process_audio(ds, audio_segment_path, output_file_handle_dict, split_duration=args.split_duration) print("\n")