From 96b61b30f2db9e405fb7b0b85877e78d71413d4e Mon Sep 17 00:00:00 2001 From: Michael McAuliffe Date: Mon, 10 Jan 2022 19:03:29 -0800 Subject: [PATCH] 2.0rc1 (#387) --- docs/source/changelog/changelog_2.0.rst | 9 + montreal_forced_aligner/abc.py | 7 +- .../acoustic_modeling/base.py | 2 +- .../acoustic_modeling/sat.py | 2 +- .../acoustic_modeling/trainer.py | 1 + montreal_forced_aligner/alignment/base.py | 478 ++++++----------- montreal_forced_aligner/alignment/mixins.py | 45 +- .../alignment/multiprocessing.py | 487 ++++++++---------- .../alignment/pretrained.py | 100 ++-- montreal_forced_aligner/command_line/align.py | 6 +- .../command_line/transcribe.py | 1 - .../command_line/validate.py | 2 - .../corpus/acoustic_corpus.py | 77 ++- montreal_forced_aligner/corpus/base.py | 53 +- montreal_forced_aligner/corpus/classes.py | 52 +- montreal_forced_aligner/corpus/text_corpus.py | 21 +- montreal_forced_aligner/data.py | 4 +- montreal_forced_aligner/dictionary/mixins.py | 2 +- .../dictionary/multispeaker.py | 12 +- .../dictionary/pronunciation.py | 6 +- montreal_forced_aligner/helper.py | 40 +- montreal_forced_aligner/ivector/trainer.py | 15 + .../language_modeling/trainer.py | 90 ++-- montreal_forced_aligner/models.py | 38 +- .../transcription/transcriber.py | 86 ++-- tests/conftest.py | 42 +- tests/data/lab/weird_words.lab | 2 +- tests/test_commandline_lm.py | 4 +- tests/test_corpus.py | 12 +- 29 files changed, 817 insertions(+), 879 deletions(-) diff --git a/docs/source/changelog/changelog_2.0.rst b/docs/source/changelog/changelog_2.0.rst index 7c30eeec..ac681948 100644 --- a/docs/source/changelog/changelog_2.0.rst +++ b/docs/source/changelog/changelog_2.0.rst @@ -10,6 +10,15 @@ Beta releases ============= +2.0.0rc1 +-------- + +- Getting closer to stable release! +- Fixed some bugs in how transcription and alignment accuracy were calculated +- Added additional information to evaluation output files +- Added file listing average per-frame log-likelihoods by utterance for alignment +- Fixed a bug where having "" in a transcript would cause MFA to crash + 2.0.0b11 -------- diff --git a/montreal_forced_aligner/abc.py b/montreal_forced_aligner/abc.py index 6db2a4fa..50560b53 100644 --- a/montreal_forced_aligner/abc.py +++ b/montreal_forced_aligner/abc.py @@ -16,7 +16,6 @@ TYPE_CHECKING, Any, Dict, - Iterable, List, Optional, Set, @@ -47,7 +46,6 @@ "TrainerMixin", "DictionaryEntryType", "ReversedMappingType", - "Labels", "WordsType", "OneToOneMappingType", "OneToManyMappingType", @@ -58,7 +56,6 @@ # Configuration types MetaDict = Dict[str, Any] -Labels: Iterable[Any] CtmErrorDict: Dict[Tuple[str, int], str] # Dictionary types @@ -426,7 +423,9 @@ def parse_args(cls, args: Optional[Namespace], unknown_args: Optional[List[str]] val = unknown_args[i + 1] unknown_dict[name] = val for name, param_type in param_types.items(): - if name.endswith("_directory") or name.endswith("_path"): + if (name.endswith("_directory") and name != "audio_directory") or name.endswith( + "_path" + ): continue if args is not None and hasattr(args, name): params[name] = param_type(getattr(args, name)) diff --git a/montreal_forced_aligner/acoustic_modeling/base.py b/montreal_forced_aligner/acoustic_modeling/base.py index 717dc077..90c1f229 100644 --- a/montreal_forced_aligner/acoustic_modeling/base.py +++ b/montreal_forced_aligner/acoustic_modeling/base.py @@ -887,7 +887,7 @@ def meta(self) -> MetaDict: "architecture": self.architecture, "train_date": str(datetime.now()), "features": self.feature_options, - "phone_set_type": str(self.phone_set_type), + "phone_set_type": str(self.worker.phone_set_type), } return data diff --git a/montreal_forced_aligner/acoustic_modeling/sat.py b/montreal_forced_aligner/acoustic_modeling/sat.py index 1f60578d..b5cb56b1 100644 --- a/montreal_forced_aligner/acoustic_modeling/sat.py +++ b/montreal_forced_aligner/acoustic_modeling/sat.py @@ -214,7 +214,7 @@ def _trainer_initialization(self) -> None: os.rename(self.model_path, self.next_model_path) self.iteration = 1 - print(os.path.exists(os.path.join(self.previous_aligner.working_directory, "trans.0.ark"))) + if os.path.exists(os.path.join(self.previous_aligner.working_directory, "trans.0.ark")): for j in self.jobs: for path in j.construct_path_dictionary( diff --git a/montreal_forced_aligner/acoustic_modeling/trainer.py b/montreal_forced_aligner/acoustic_modeling/trainer.py index 0e892b8c..9e069e69 100644 --- a/montreal_forced_aligner/acoustic_modeling/trainer.py +++ b/montreal_forced_aligner/acoustic_modeling/trainer.py @@ -346,6 +346,7 @@ def align(self) -> None: f"Analyzing alignment diagnostics for {self.current_aligner.identifier} on the full corpus" ) self.compile_information() + self.collect_alignments() with open(done_path, "w"): pass except Exception as e: diff --git a/montreal_forced_aligner/alignment/base.py b/montreal_forced_aligner/alignment/base.py index ddde72c3..ad1c7480 100644 --- a/montreal_forced_aligner/alignment/base.py +++ b/montreal_forced_aligner/alignment/base.py @@ -4,9 +4,7 @@ import multiprocessing as mp import os import shutil -import sys import time -import traceback from queue import Empty from typing import List, Optional @@ -15,22 +13,15 @@ from montreal_forced_aligner.abc import FileExporterMixin from montreal_forced_aligner.alignment.mixins import AlignMixin from montreal_forced_aligner.alignment.multiprocessing import ( - AliToCtmArguments, - AliToCtmFunction, ExportTextGridArguments, ExportTextGridProcessWorker, - PhoneCtmArguments, - PhoneCtmProcessWorker, - WordCtmArguments, - WordCtmProcessWorker, + PhoneAlignmentArguments, + PhoneAlignmentFunction, + WordAlignmentArguments, + WordAlignmentFunction, ) from montreal_forced_aligner.corpus.acoustic_corpus import AcousticCorpusPronunciationMixin -from montreal_forced_aligner.exceptions import AlignmentExportError -from montreal_forced_aligner.textgrid import ( - export_textgrid, - output_textgrid_writing_errors, - process_ctm_line, -) +from montreal_forced_aligner.textgrid import export_textgrid, output_textgrid_writing_errors from montreal_forced_aligner.utils import KaldiProcessWorker, Stopped __all__ = ["CorpusAligner"] @@ -52,8 +43,9 @@ class CorpusAligner(AcousticCorpusPronunciationMixin, AlignMixin, FileExporterMi def __init__(self, **kwargs): super().__init__(**kwargs) + self.export_output_directory = None - def word_ctm_arguments(self) -> List[WordCtmArguments]: + def word_alignment_arguments(self) -> List[WordAlignmentArguments]: """ Generate Job arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.WordCtmProcessWorker` @@ -63,20 +55,25 @@ def word_ctm_arguments(self) -> List[WordCtmArguments]: Arguments for processing """ return [ - WordCtmArguments( - j.construct_path_dictionary(self.working_directory, "word", "ctm"), + WordAlignmentArguments( + os.path.join(self.working_log_directory, f"get_word_ctm.{j.name}.log"), + self.alignment_model_path, + round(self.frame_shift / 1000, 4), + self.cleanup_textgrids, + self.oov_word, + self.sanitize_function, j.current_dictionary_names, + j.construct_path_dictionary(self.working_directory, "ali", "ark"), + j.word_boundary_int_files(), + j.construct_path_dictionary(self.data_directory, "text", "int.scp"), {d.name: d.reversed_word_mapping for d in self.dictionary_mapping.values()}, j.text_scp_data(), j.utt2spk_scp_data(), - self.sanitize_function, - self.cleanup_textgrids, - self.oov_word, ) for j in self.jobs ] - def phone_ctm_arguments(self) -> List[PhoneCtmArguments]: + def phone_alignment_arguments(self) -> List[PhoneAlignmentArguments]: """ Generate Job arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.PhoneCtmProcessWorker` @@ -86,13 +83,18 @@ def phone_ctm_arguments(self) -> List[PhoneCtmArguments]: Arguments for processing """ return [ - PhoneCtmArguments( - j.construct_path_dictionary(self.working_directory, "phone", "ctm"), - j.current_dictionary_names, - self.reversed_phone_mapping, + PhoneAlignmentArguments( + os.path.join(self.working_log_directory, f"get_phone_ctm.{j.name}.log"), + self.alignment_model_path, + round(self.frame_shift / 1000, 4), self.position_dependent_phones, self.cleanup_textgrids, self.optional_silence_phone, + j.current_dictionary_names, + j.construct_path_dictionary(self.working_directory, "ali", "ark"), + j.word_boundary_int_files(), + j.construct_path_dictionary(self.data_directory, "text", "int.scp"), + self.reversed_phone_mapping, ) for j in self.jobs ] @@ -110,7 +112,7 @@ def export_textgrid_arguments(self) -> List[ExportTextGridArguments]: ExportTextGridArguments( os.path.join(self.working_log_directory, f"export_textgrids.{j.name}.log"), self.frame_shift, - self.textgrid_output, + self.export_output_directory, self.backup_output_directory, ) for j in self.jobs @@ -121,186 +123,32 @@ def backup_output_directory(self) -> Optional[str]: """Backup directory if overwriting is not allowed""" return None - def ctms_to_textgrids_mp(self): - """ - Multiprocessing function for exporting alignment CTM information as TextGrids - - See Also - -------- - :class:`~montreal_forced_aligner.alignment.multiprocessing.AliToCtmFunction` - Multiprocessing helper function for converting ali archives to CTM format - :class:`~montreal_forced_aligner.alignment.multiprocessing.PhoneCtmProcessWorker` - Multiprocessing helper class for processing CTM files - :meth:`.CorpusAligner.phone_ctm_arguments` - Job method for generating arguments for PhoneCtmProcessWorker - :class:`~montreal_forced_aligner.alignment.multiprocessing.WordCtmProcessWorker` - Multiprocessing helper class for processing word CTM files - :meth:`.CorpusAligner.word_ctm_arguments` - Job method for generating arguments for WordCtmProcessWorker - :class:`~montreal_forced_aligner.alignment.multiprocessing.ExportTextGridProcessWorker` - Multiprocessing helper class for exporting TextGrid files - :meth:`.CorpusAligner.export_textgrid_arguments` - Job method for generating arguments for ExportTextGridProcessWorker - :kaldi_steps:`get_train_ctm` - Reference Kaldi script - - """ - export_begin = time.time() - manager = mp.Manager() - textgrid_errors = manager.dict() - error_catching = manager.dict() - stopped = Stopped() - if not self.overwrite: - os.makedirs(self.backup_output_directory, exist_ok=True) - - self.logger.debug("Beginning to process ctm files...") - word_procs = [] - phone_procs = [] - finished_processing = Stopped() - to_process_queue = mp.JoinableQueue() - for_write_queue = mp.JoinableQueue() - total_files = len(self.files) - word_ctm_args = self.word_ctm_arguments() - phone_ctm_args = self.phone_ctm_arguments() - export_args = self.export_textgrid_arguments() - for j in self.jobs: - word_p = WordCtmProcessWorker( - j.name, - to_process_queue, - stopped, - error_catching, - word_ctm_args[j.name], - ) - - word_procs.append(word_p) - word_p.start() - - phone_p = PhoneCtmProcessWorker( - j.name, - to_process_queue, - stopped, - error_catching, - phone_ctm_args[j.name], - ) - phone_p.start() - phone_procs.append(phone_p) - - export_procs = [] - for j in self.jobs: - export_proc = ExportTextGridProcessWorker( - for_write_queue, - stopped, - finished_processing, - textgrid_errors, - export_args[j.name], - ) - export_proc.start() - export_procs.append(export_proc) - try: - with tqdm.tqdm(total=total_files) as pbar: - while True: - try: - w_p, intervals = to_process_queue.get(timeout=1) - except Empty: - for proc in word_procs: - if not proc.finished_signal.stop_check(): - break - for proc in phone_procs: - if not proc.finished_signal.stop_check(): - break - else: - break - continue - to_process_queue.task_done() - if self.stopped.stop_check(): - self.logger.debug("Got stop check, exiting") - continue - utt = self.utterances[intervals[0].utterance] - if w_p == "word": - utt.add_word_intervals(intervals) - else: - utt.add_phone_intervals(intervals) - file = self.files[utt.file_name] - if file.is_fully_aligned: - tiers = file.aligned_data - output_path = file.construct_output_path( - self.textgrid_output, self.backup_output_directory - ) - duration = file.duration - for_write_queue.put((tiers, output_path, duration)) - pbar.update(1) - except Exception: - stopped.stop() - while True: - try: - _ = to_process_queue.get(timeout=1) - except Empty: - for proc in word_procs: - if not proc.finished_signal.stop_check(): - break - for proc in phone_procs: - if not proc.finished_signal.stop_check(): - break - else: - break - continue - to_process_queue.task_done() - exc_type, exc_value, exc_traceback = sys.exc_info() - error_catching["main"] = "\n".join( - traceback.format_exception(exc_type, exc_value, exc_traceback) - ) - finally: - finished_processing.stop() - self.logger.debug("Waiting for processes to finish...") - for i in range(self.num_jobs): - word_procs[i].join() - phone_procs[i].join() - to_process_queue.join() - - for_write_queue.join() - for i in range(self.num_jobs): - export_procs[i].join() - self.logger.debug(f"Export took {time.time() - export_begin} seconds") - - if error_catching: - self.logger.error("Error was encountered in processing CTMs") - for key, error in error_catching.items(): - self.logger.error(f"{key}:\n\n{error}") - raise AlignmentExportError(error_catching) - - if textgrid_errors: - self.logger.warning( - f"There were {len(textgrid_errors)} errors encountered in generating TextGrids. " - f"Check the output_errors.txt file in {os.path.join(self.textgrid_output)} " - f"for more details" - ) - output_textgrid_writing_errors(self.textgrid_output, textgrid_errors) - - def ali_to_ctm(self, word_mode=True): + def _collect_alignments(self, word_mode=True): """ - Convert alignment archives to CTM format + Process alignment archives to extract word or phone alignments Parameters ---------- word_mode: bool - Flag for generating word or phone CTMs + Flag for collecting word or phone alignments See Also -------- - :class:`~montreal_forced_aligner.alignment.multiprocessing.AliToCtmFunction` - Multiprocessing function - :meth:`.CorpusAligner.ali_to_word_ctm_arguments` + :class:`~montreal_forced_aligner.alignment.multiprocessing.WordAlignmentFunction` + Multiprocessing function for words alignments + :class:`~montreal_forced_aligner.alignment.multiprocessing.PhoneAlignmentFunction` + Multiprocessing function for phone alignments + :meth:`.CorpusAligner.word_alignment_arguments` Arguments for word CTMS - :meth:`.CorpusAligner.ali_to_phone_ctm_arguments` - Arguments for phone CTMS + :meth:`.CorpusAligner.phone_alignment_arguments` + Arguments for phone alignment """ if word_mode: - self.logger.info("Generating word CTM files from alignment lattices...") - jobs = self.ali_to_word_ctm_arguments() # Word CTM jobs + self.logger.info("Collecting word alignments from alignment lattices...") + jobs = self.word_alignment_arguments() # Word CTM jobs else: - self.logger.info("Generating phone CTM files from alignment lattices...") - jobs = self.ali_to_phone_ctm_arguments() # Phone CTM jobs - sum_errors = 0 + self.logger.info("Collecting phone alignments from alignment lattices...") + jobs = self.phone_alignment_arguments() # Phone CTM jobs with tqdm.tqdm(total=self.num_utterances) as pbar: if self.use_mp: manager = mp.Manager() @@ -309,14 +157,16 @@ def ali_to_ctm(self, word_mode=True): stopped = Stopped() procs = [] for i, args in enumerate(jobs): - function = AliToCtmFunction(args) + if word_mode: + function = WordAlignmentFunction(args) + else: + function = PhoneAlignmentFunction(args) p = KaldiProcessWorker(i, return_queue, function, error_dict, stopped) procs.append(p) p.start() while True: try: - done, errors = return_queue.get(timeout=1) - sum_errors += errors + utterance, intervals = return_queue.get(timeout=1) if stopped.stop_check(): continue except Empty: @@ -326,7 +176,11 @@ def ali_to_ctm(self, word_mode=True): else: break continue - pbar.update(done + errors) + pbar.update(1) + if word_mode: + self.utterances[utterance].add_word_intervals(intervals) + else: + self.utterances[utterance].add_phone_intervals(intervals) for p in procs: p.join() if error_dict: @@ -334,103 +188,113 @@ def ali_to_ctm(self, word_mode=True): raise v else: for args in jobs: - function = AliToCtmFunction(args) - for done, errors in function.run(): - sum_errors += errors - pbar.update(done + errors) - if sum_errors: - self.logger.warning(f"{errors} utterances had errors during creating CTM files.") + if word_mode: + function = WordAlignmentFunction(args) + else: + function = PhoneAlignmentFunction(args) + for utterance, intervals in function.run(): + if word_mode: + self.utterances[utterance].add_word_intervals(intervals) + else: + self.utterances[utterance].add_phone_intervals(intervals) + pbar.update(1) + + def collect_word_alignments(self): + self._collect_alignments(True) + + def collect_phone_alignments(self): + self._collect_alignments(False) - def convert_ali_to_textgrids(self) -> None: + def collect_alignments(self): + if self.alignment_done: + return + self.collect_word_alignments() + self.collect_phone_alignments() + self.alignment_done = True + + def export_textgrids(self) -> None: """ - Multiprocessing function that aligns based on the current model. + Exports alignments to TextGrid files See Also -------- - :class:`~montreal_forced_aligner.alignment.multiprocessing.AliToCtmFunction` - Multiprocessing helper function for each job - :meth:`.CorpusAligner.ali_to_word_ctm_arguments` - Job method for generating arguments for this function - :meth:`.CorpusAligner.ali_to_phone_ctm_arguments` - Job method for generating arguments for this function - :kaldi_steps:`get_train_ctm` - Reference Kaldi script + :class:`~montreal_forced_aligner.alignment.multiprocessing.ExportTextGridProcessWorker` + Multiprocessing helper function for TextGrid export + :meth:`.CorpusAligner.export_textgrid_arguments` + Job method for TextGrid export """ - os.makedirs(self.textgrid_output, exist_ok=True) - self.logger.info("Generating CTMs from alignment...") - self.ali_to_ctm(True) - self.ali_to_ctm(False) - self.logger.info("Finished generating CTMs!") - - self.logger.info("Exporting TextGrids from CTMs...") - if self.use_mp: - self.ctms_to_textgrids_mp() - else: - self.ctms_to_textgrids_non_mp() - self.logger.info("Finished exporting TextGrids!") + begin = time.time() + self.logger.info("Exporting TextGrids...") + os.makedirs(self.export_output_directory, exist_ok=True) + if self.backup_output_directory: + os.makedirs(self.backup_output_directory, exist_ok=True) - def ctms_to_textgrids_non_mp(self) -> None: - """ - Parse CTM files to TextGrids without using multiprocessing - """ - self.log_debug("Not using multiprocessing for TextGrid export") export_errors = {} - w_args = self.word_ctm_arguments() - p_args = self.phone_ctm_arguments() - for j in self.jobs: - - word_arguments = w_args[j.name] - phone_arguments = p_args[j.name] - self.logger.debug(f"Parsing ctms for job {j.name}...") - - for dict_name in word_arguments.dictionaries: - with open(word_arguments.ctm_paths[dict_name], "r") as f: - for line in f: - line = line.strip() - if line == "": - continue - interval = process_ctm_line(line) - utt = self.utterances[interval.utterance] - dictionary = self.get_dictionary(utt.speaker_name) - label = dictionary.reversed_word_mapping[int(interval.label)] - - interval.label = label - utt.add_word_intervals(interval) - - for dict_name in phone_arguments.dictionaries: - with open(phone_arguments.ctm_paths[dict_name], "r") as f: - for line in f: - line = line.strip() - if line == "": - continue - interval = process_ctm_line(line) - utt = self.utterances[interval.utterance] - dictionary = self.get_dictionary(utt.speaker_name) - - label = dictionary.reversed_phone_mapping[int(interval.label)] - if self.position_dependent_phones: - for p in dictionary.positions: - if label.endswith(p): - label = label[: -1 * len(p)] - interval.label = label - utt.add_phone_intervals(interval) - for file in self.files: - data = file.aligned_data + total_files = len(self.files) + with tqdm.tqdm(total=total_files) as pbar: + if self.use_mp: + manager = mp.Manager() + textgrid_errors = manager.dict() + stopped = Stopped() - backup_output_directory = None - if not self.overwrite: - backup_output_directory = self.backup_output_directory - os.makedirs(backup_output_directory, exist_ok=True) - output_path = file.construct_output_path(self.textgrid_output, backup_output_directory) - export_textgrid(data, output_path, file.duration, self.frame_shift) + finished_processing = Stopped() + for_write_queue = mp.JoinableQueue() + export_args = self.export_textgrid_arguments() + + export_procs = [] + for j in self.jobs: + export_proc = ExportTextGridProcessWorker( + for_write_queue, + stopped, + finished_processing, + textgrid_errors, + export_args[j.name], + ) + export_proc.start() + export_procs.append(export_proc) + try: + for file in self.files: + tiers = file.aligned_data + output_path = file.construct_output_path( + self.export_output_directory, self.backup_output_directory + ) + duration = file.duration + for_write_queue.put((tiers, output_path, duration)) + pbar.update(1) + except Exception: + stopped.stop() + raise + finally: + finished_processing.stop() + + for_write_queue.join() + for i in range(self.num_jobs): + export_procs[i].join() + export_errors.update(textgrid_errors) + else: + self.log_debug("Not using multiprocessing for TextGrid export") + for file in self.files: + data = file.aligned_data + + backup_output_directory = None + if not self.overwrite: + backup_output_directory = self.backup_output_directory + os.makedirs(backup_output_directory, exist_ok=True) + output_path = file.construct_output_path( + self.export_output_directory, backup_output_directory + ) + export_textgrid(data, output_path, file.duration, self.frame_shift) + pbar.update(1) if export_errors: self.logger.warning( f"There were {len(export_errors)} errors encountered in generating TextGrids. " - f"Check the output_errors.txt file in {os.path.join(self.textgrid_output)} " + f"Check the output_errors.txt file in {os.path.join(self.export_output_directory)} " f"for more details" ) - output_textgrid_writing_errors(self.textgrid_output, export_errors) + output_textgrid_writing_errors(self.export_output_directory, export_errors) + self.logger.info("Finished exporting TextGrids!") + self.logger.debug(f"Exported TextGrids in a total of {time.time() - begin} seconds") def export_files(self, output_directory: str) -> None: """ @@ -441,59 +305,9 @@ def export_files(self, output_directory: str) -> None: output_directory: str Directory to save to """ - begin = time.time() - self.textgrid_output = output_directory + self.export_output_directory = output_directory if self.backup_output_directory is not None and os.path.exists( self.backup_output_directory ): shutil.rmtree(self.backup_output_directory, ignore_errors=True) - self.convert_ali_to_textgrids() - self.logger.debug(f"Exported TextGrids in a total of {time.time() - begin} seconds") - - def ali_to_word_ctm_arguments(self) -> List[AliToCtmArguments]: - """ - Generate Job arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.AliToCtmFunction` - - Returns - ------- - list[:class:`~montreal_forced_aligner.alignment.multiprocessing.AliToCtmArguments`] - Arguments for processing - """ - return [ - AliToCtmArguments( - os.path.join(self.working_log_directory, f"get_word_ctm.{j.name}.log"), - j.current_dictionary_names, - j.construct_path_dictionary(self.working_directory, "ali", "ark"), - j.construct_path_dictionary(self.data_directory, "text", "int.scp"), - j.word_boundary_int_files(), - round(self.frame_shift / 1000, 4), - self.alignment_model_path, - j.construct_path_dictionary(self.working_directory, "word", "ctm"), - True, - ) - for j in self.jobs - ] - - def ali_to_phone_ctm_arguments(self) -> List[AliToCtmArguments]: - """ - Generate Job arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.AliToCtmFunction` - - Returns - ------- - list[:class:`~montreal_forced_aligner.alignment.multiprocessing.AliToCtmArguments`] - Arguments for processing - """ - return [ - AliToCtmArguments( - os.path.join(self.working_log_directory, f"get_phone_ctm.{j.name}.log"), - j.current_dictionary_names, - j.construct_path_dictionary(self.working_directory, "ali", "ark"), - j.construct_path_dictionary(self.data_directory, "text", "int.scp"), - j.word_boundary_int_files(), - round(self.frame_shift / 1000, 4), - self.alignment_model_path, - j.construct_path_dictionary(self.working_directory, "phone", "ctm"), - False, - ) - for j in self.jobs - ] + self.export_textgrids() diff --git a/montreal_forced_aligner/alignment/mixins.py b/montreal_forced_aligner/alignment/mixins.py index f9f35565..b42edcfd 100644 --- a/montreal_forced_aligner/alignment/mixins.py +++ b/montreal_forced_aligner/alignment/mixins.py @@ -1,6 +1,7 @@ """Class definitions for alignment mixins""" from __future__ import annotations +import csv import logging import multiprocessing as mp import os @@ -309,7 +310,7 @@ def align_utterances(self) -> None: p.start() while True: try: - utterance, succeeded = return_queue.get(timeout=1) + utterance, log_likelihood = return_queue.get(timeout=1) if stopped.stop_check(): continue except Empty: @@ -319,11 +320,17 @@ def align_utterances(self) -> None: else: break continue - if not succeeded and hasattr(self, "utterances"): - self.utterances[utterance].phone_labels = [] - self.utterances[utterance].word_labels = [] - else: - pbar.update(1) + if hasattr(self, "utterances"): + if hasattr(self, "frame_shift"): + num_frames = int( + self.utterances[utterance].duration * self.frame_shift + ) + else: + num_frames = self.utterances[utterance].duration + self.utterances[utterance].alignment_log_likelihood = ( + log_likelihood / num_frames + ) + pbar.update(1) for p in procs: p.join() if error_dict: @@ -333,12 +340,18 @@ def align_utterances(self) -> None: self.logger.debug("Not using multiprocessing...") for args in self.align_arguments(): function = AlignFunction(args) - for utterance, succeeded in function.run(): - if not succeeded and hasattr(self, "utterances"): - self.utterances[utterance].phone_labels = [] - self.utterances[utterance].word_labels = [] - else: - pbar.update(1) + for utterance, log_likelihood in function.run(): + if hasattr(self, "utterances"): + if hasattr(self, "frame_shift"): + num_frames = int( + self.utterances[utterance].duration * self.frame_shift + ) + else: + num_frames = self.utterances[utterance].duration + self.utterances[utterance].alignment_log_likelihood = ( + log_likelihood / num_frames + ) + pbar.update(1) self.compile_information() self.logger.debug(f"Alignment round took {time.time() - begin}") @@ -383,6 +396,14 @@ def compile_information(self): average_logdet_frames += data["logdet_frames"] average_logdet_sum += data["logdet"] * data["logdet_frames"] + if hasattr(self, "utterances"): + csv_path = os.path.join(self.working_directory, "alignment_log_likelihood.csv") + with open(csv_path, "w", newline="", encoding="utf8") as f: + writer = csv.writer(f) + writer.writerow(["utterance", "loglikelihood"]) + for u in self.utterances: + writer.writerow([u.name, u.alignment_log_likelihood]) + if not avg_like_frames: self.logger.warning( "No files were aligned, this likely indicates serious problems with the aligner." diff --git a/montreal_forced_aligner/alignment/multiprocessing.py b/montreal_forced_aligner/alignment/multiprocessing.py index 3109ccf9..aba357e9 100644 --- a/montreal_forced_aligner/alignment/multiprocessing.py +++ b/montreal_forced_aligner/alignment/multiprocessing.py @@ -20,19 +20,17 @@ from montreal_forced_aligner.utils import KaldiFunction, Stopped, thirdparty_binary if TYPE_CHECKING: - from montreal_forced_aligner.abc import CtmErrorDict, MetaDict + from montreal_forced_aligner.abc import MetaDict __all__ = [ - "WordCtmProcessWorker", - "PhoneCtmProcessWorker", + "WordAlignmentFunction", + "PhoneAlignmentFunction", "ExportTextGridProcessWorker", - "WordCtmArguments", - "PhoneCtmArguments", + "WordAlignmentArguments", + "PhoneAlignmentArguments", "ExportTextGridArguments", "AlignFunction", "AlignArguments", - "AliToCtmFunction", - "AliToCtmArguments", "compile_information_func", "CompileInformationArguments", "CompileTrainGraphsFunction", @@ -40,42 +38,38 @@ ] -class AliToCtmArguments(NamedTuple): - """Arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.AliToCtmFunction`""" +class WordAlignmentArguments(NamedTuple): + """Arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.WordAlignmentFunction`""" log_path: str + model_path: str + frame_shift: float + cleanup_textgrids: bool + oov_word: str + sanitize_function: MultispeakerSanitizationFunction dictionaries: List[str] ali_paths: Dict[str, str] - text_int_paths: Dict[str, str] word_boundary_int_paths: Dict[str, str] - frame_shift: float - model_path: str - ctm_paths: Dict[str, str] - word_mode: bool - - -class WordCtmArguments(NamedTuple): - """Arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.WordCtmProcessWorker`""" - - ctm_paths: Dict[str, str] - dictionaries: List[str] + text_int_paths: Dict[str, str] reversed_word_mappings: Dict[str, Dict[int, str]] utterance_texts: Dict[str, Dict[str, str]] utterance_speakers: Dict[str, Dict[str, str]] - sanitize_function: MultispeakerSanitizationFunction - cleanup_textgrids: bool - oov_word: str -class PhoneCtmArguments(NamedTuple): - """Arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.PhoneCtmProcessWorker`""" +class PhoneAlignmentArguments(NamedTuple): + """Arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.PhoneAlignmentFunction`""" - ctm_paths: Dict[str, str] - dictionaries: List[str] - reversed_phone_mapping: Dict[int, str] + log_path: str + model_path: str + frame_shift: float position_dependent_phones: bool cleanup_textgrids: bool silence_phone: str + dictionaries: List[str] + ali_paths: Dict[str, str] + word_boundary_int_paths: Dict[str, str] + text_int_paths: Dict[str, str] + reversed_phone_mapping: Dict[int, str] class ExportTextGridArguments(NamedTuple): @@ -280,13 +274,6 @@ class AlignFunction(KaldiFunction): Arguments for the function """ - progress_pattern = re.compile( - r"^LOG \(gmm-align-compiled.*gmm-align-compiled.cc:127\) (?P.*)" - ) - error_pattern = re.compile( - r"^WARNING \(gmm-align-compiled.*Did not successfully decode file (?P.*),.*" - ) - def __init__(self, args: AlignArguments): self.log_path = args.log_path self.dictionaries = args.dictionaries @@ -315,6 +302,7 @@ def run(self): f"scp:{fst_path}", feature_string, f"ark:{ali_path}", + "ark,t:-", ] boost_proc = subprocess.Popen( @@ -331,27 +319,16 @@ def run(self): ) align_proc = subprocess.Popen( com, - stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=log_file, encoding="utf8", stdin=boost_proc.stdout, env=os.environ, ) - for line in align_proc.stderr: - log_file.write(line) + for line in align_proc.stdout: line = line.strip() - if "Overall" in line: - continue - if "Retried" in line: - continue - if "Done" in line: - continue - m = self.error_pattern.match(line) - if m: - yield m.group("utterance"), False - else: - m = self.progress_pattern.match(line) - if m: - yield m.group("utterance"), True + utterance, log_likelihood = line.split() + yield utterance, float(log_likelihood) def compile_information_func(align_log_path: str) -> Dict[str, Union[List[str], float, int]]: @@ -406,18 +383,16 @@ def compile_information_func(align_log_path: str) -> Dict[str, Union[List[str], return data -class AliToCtmFunction(KaldiFunction): +class WordAlignmentFunction(KaldiFunction): """ - Multiprocessing function to convert alignment archives into CTM files + Multiprocessing function to collect word alignments from the aligned lattice See Also -------- - :meth:`.CorpusAligner.ctms_to_textgrids_mp` + :meth:`.CorpusAligner.collect_word_alignments` Main function that calls this function in parallel - :meth:`.CorpusAligner.ali_to_word_ctm_arguments` - Job method for generating arguments for this function - :meth:`.CorpusAligner.ali_to_phone_ctm_arguments` + :meth:`.CorpusAligner.word_alignments_arguments` Job method for generating arguments for this function :kaldi_src:`linear-to-nbest` Relevant Kaldi binary @@ -429,148 +404,16 @@ class AliToCtmFunction(KaldiFunction): Relevant Kaldi binary :kaldi_src:`nbest-to-ctm` Relevant Kaldi binary + :kaldi_steps:`get_train_ctm` + Reference Kaldi script Parameters ---------- - args: :class:`~montreal_forced_aligner.alignment.multiprocessing.AliToCtmArguments` + arguments: :class:`~montreal_forced_aligner.alignment.multiprocessing.WordAlignmentArguments` Arguments for the function """ - progress_pattern = re.compile( - r"^LOG.* Converted (?P\d+) linear lattices to ctm format; (?P\d+) had errors." - ) - - def __init__(self, args: AliToCtmArguments): - self.log_path = args.log_path - self.dictionaries = args.dictionaries - self.ali_paths = args.ali_paths - self.text_int_paths = args.text_int_paths - self.word_boundary_int_paths = args.word_boundary_int_paths - self.frame_shift = args.frame_shift - self.model_path = args.model_path - self.ctm_paths = args.ctm_paths - self.word_mode = args.word_mode - - def run(self): - """Run the function""" - with open(self.log_path, "w", encoding="utf8") as log_file: - for dict_name in self.dictionaries: - ali_path = self.ali_paths[dict_name] - text_int_path = self.text_int_paths[dict_name] - ctm_path = self.ctm_paths[dict_name] - word_boundary_int_path = self.word_boundary_int_paths[dict_name] - if os.path.exists(ctm_path): - return - lin_proc = subprocess.Popen( - [ - thirdparty_binary("linear-to-nbest"), - "ark:" + ali_path, - "ark:" + text_int_path, - "", - "", - "ark:-", - ], - stdout=subprocess.PIPE, - stderr=log_file, - env=os.environ, - ) - align_words_proc = subprocess.Popen( - [ - thirdparty_binary("lattice-align-words"), - word_boundary_int_path, - self.model_path, - "ark:-", - "ark:-", - ], - stdin=lin_proc.stdout, - stdout=subprocess.PIPE, - stderr=log_file, - env=os.environ, - ) - if self.word_mode: - nbest_proc = subprocess.Popen( - [ - thirdparty_binary("nbest-to-ctm"), - f"--frame-shift={self.frame_shift}", - "ark:-", - ctm_path, - ], - stderr=subprocess.PIPE, - stdin=align_words_proc.stdout, - env=os.environ, - encoding="utf8", - ) - else: - phone_proc = subprocess.Popen( - [ - thirdparty_binary("lattice-to-phone-lattice"), - self.model_path, - "ark:-", - "ark:-", - ], - stdout=subprocess.PIPE, - stdin=align_words_proc.stdout, - stderr=log_file, - env=os.environ, - ) - nbest_proc = subprocess.Popen( - [ - thirdparty_binary("nbest-to-ctm"), - f"--frame-shift={self.frame_shift}", - "ark:-", - ctm_path, - ], - stdin=phone_proc.stdout, - stderr=subprocess.PIPE, - env=os.environ, - encoding="utf8", - ) - for line in nbest_proc.stderr: - m = self.progress_pattern.match(line.strip()) - if m: - yield int(m.group("done")), int(m.group("errors")) - - -class WordCtmProcessWorker(mp.Process): - """ - Multiprocessing worker for loading word CTM files - - See Also - -------- - :meth:`.CorpusAligner.ctms_to_textgrids_mp` - Main function that runs this worker in parallel - - Parameters - ---------- - job_name: int - Job name - to_process_queue: :class:`~multiprocessing.Queue` - Return queue of jobs for later workers to process - stopped: :class:`~montreal_forced_aligner.utils.Stopped` - Stop check for processing - error_catching: dict[tuple[str, int], str] - Dictionary for storing errors encountered - arguments: :class:`~montreal_forced_aligner.alignment.multiprocessing.WordCtmArguments` - Arguments to pass to the CTM processing function - """ - - def __init__( - self, - job_name: int, - to_process_queue: mp.Queue, - stopped: Stopped, - error_catching: CtmErrorDict, - arguments: WordCtmArguments, - ): - mp.Process.__init__(self) - self.job_name = job_name - self.dictionaries = arguments.dictionaries - self.ctm_paths = arguments.ctm_paths - self.to_process_queue = to_process_queue - self.stopped = stopped - self.error_catching = error_catching - self.finished_signal = Stopped() - + def __init__(self, arguments: WordAlignmentArguments): self.arguments = arguments def cleanup_intervals(self, utterance_name: str, intervals: List[CtmInterval]): @@ -622,82 +465,106 @@ def cleanup_intervals(self, utterance_name: str, intervals: List[CtmInterval]): raise return actual_labels - def run(self) -> None: - """ - Run the word processing - """ - cur_utt = None - intervals = [] - try: - for dict_name in self.dictionaries: - ctm_path = self.ctm_paths[dict_name] - with open(ctm_path, "r") as word_file: - for line in word_file: - line = line.strip() - if not line: - continue + def run(self): + """Run the function""" + with open(self.arguments.log_path, "w", encoding="utf8") as log_file: + for dict_name in self.arguments.dictionaries: + cur_utt = None + intervals = [] + ali_path = self.arguments.ali_paths[dict_name] + text_int_path = self.arguments.text_int_paths[dict_name] + word_boundary_int_path = self.arguments.word_boundary_int_paths[dict_name] + lin_proc = subprocess.Popen( + [ + thirdparty_binary("linear-to-nbest"), + f"ark:{ali_path}", + f"ark:{text_int_path}", + "", + "", + "ark:-", + ], + stdout=subprocess.PIPE, + stderr=log_file, + env=os.environ, + ) + align_words_proc = subprocess.Popen( + [ + thirdparty_binary("lattice-align-words"), + word_boundary_int_path, + self.arguments.model_path, + "ark:-", + "ark:-", + ], + stdin=lin_proc.stdout, + stdout=subprocess.PIPE, + stderr=log_file, + env=os.environ, + ) + nbest_proc = subprocess.Popen( + [ + thirdparty_binary("nbest-to-ctm"), + "--print-args=false", + f"--frame-shift={self.arguments.frame_shift}", + "ark:-", + "-", + ], + stderr=log_file, + stdin=align_words_proc.stdout, + stdout=subprocess.PIPE, + env=os.environ, + encoding="utf8", + ) + for line in nbest_proc.stdout: + line = line.strip() + if not line: + continue + try: interval = process_ctm_line(line) - if cur_utt is None: - cur_utt = interval.utterance - if cur_utt != interval.utterance: - - self.to_process_queue.put( - ("word", self.cleanup_intervals(cur_utt, intervals)) - ) - intervals = [] - cur_utt = interval.utterance - intervals.append(interval) - if intervals: - self.to_process_queue.put(("word", self.cleanup_intervals(cur_utt, intervals))) + except ValueError: + continue + if cur_utt is None: + cur_utt = interval.utterance + if cur_utt != interval.utterance: + yield cur_utt, self.cleanup_intervals(cur_utt, intervals) + intervals = [] + cur_utt = interval.utterance + intervals.append(interval) + if intervals: + yield cur_utt, self.cleanup_intervals(cur_utt, intervals) - except Exception: - self.stopped.stop() - exc_type, exc_value, exc_traceback = sys.exc_info() - self.error_catching[("word", self.job_name)] = "\n".join( - traceback.format_exception(exc_type, exc_value, exc_traceback) - ) - finally: - self.finished_signal.stop() +class PhoneAlignmentFunction(KaldiFunction): -class PhoneCtmProcessWorker(mp.Process): """ - Multiprocessing worker for loading phone CTM files + Multiprocessing function to collect phone alignments from the aligned lattice See Also -------- - :meth:`.CorpusAligner.ctms_to_textgrids_mp` - Main function that runs this worker in parallel + :meth:`.CorpusAligner.collect_phone_alignments` + Main function that calls this function in parallel + :meth:`.CorpusAligner.phone_alignments_arguments` + Job method for generating arguments for this function + :kaldi_src:`linear-to-nbest` + Relevant Kaldi binary + :kaldi_src:`lattice-determinize-pruned` + Relevant Kaldi binary + :kaldi_src:`lattice-align-words` + Relevant Kaldi binary + :kaldi_src:`lattice-to-phone-lattice` + Relevant Kaldi binary + :kaldi_src:`nbest-to-ctm` + Relevant Kaldi binary + :kaldi_steps:`get_train_ctm` + Reference Kaldi script Parameters ---------- - job_name: int - Job name - to_process_queue: :class:`~multiprocessing.Queue` - Return queue of jobs for later workers to process - stopped: :class:`~montreal_forced_aligner.utils.Stopped` - Stop check for processing - error_catching: dict[tuple[str, int], str] - Dictionary for storing errors encountered - arguments: :class:`~montreal_forced_aligner.alignment.multiprocessing.PhoneCtmArguments` - Arguments to pass to the CTM processing function + arguments: :class:`~montreal_forced_aligner.alignment.multiprocessing.PhoneAlignmentArguments` + Arguments for the function """ - def __init__( - self, - job_name: int, - to_process_queue: mp.Queue, - stopped: Stopped, - error_catching: CtmErrorDict, - arguments: PhoneCtmArguments, - ): - mp.Process.__init__(self) - self.job_name = job_name + def __init__(self, arguments: PhoneAlignmentArguments): self.arguments = arguments - self.to_process_queue = to_process_queue - self.stopped = stopped - self.error_catching = error_catching - self.finished_signal = Stopped() def cleanup_intervals(self, intervals: List[CtmInterval]): actual_labels = [] @@ -711,38 +578,86 @@ def cleanup_intervals(self, intervals: List[CtmInterval]): actual_labels.append(interval) return actual_labels - def run(self) -> None: - """Run the phone processing""" - cur_utt = None - intervals = [] - try: + def run(self): + """Run the function""" + with open(self.arguments.log_path, "w", encoding="utf8") as log_file: for dict_name in self.arguments.dictionaries: - ctm_path = self.arguments.ctm_paths[dict_name] - with open(ctm_path, "r") as word_file: - for line in word_file: - line = line.strip() - if not line: - continue - interval = process_ctm_line(line) - if cur_utt is None: - cur_utt = interval.utterance - if cur_utt != interval.utterance: - - self.to_process_queue.put(("phone", self.cleanup_intervals(intervals))) - intervals = [] - cur_utt = interval.utterance - intervals.append(interval) - if intervals: - self.to_process_queue.put(("phone", self.cleanup_intervals(intervals))) + cur_utt = None + intervals = [] + ali_path = self.arguments.ali_paths[dict_name] + text_int_path = self.arguments.text_int_paths[dict_name] + word_boundary_int_path = self.arguments.word_boundary_int_paths[dict_name] + lin_proc = subprocess.Popen( + [ + thirdparty_binary("linear-to-nbest"), + f"ark:{ali_path}", + f"ark:{text_int_path}", + "", + "", + "ark:-", + ], + stdout=subprocess.PIPE, + stderr=log_file, + env=os.environ, + ) + align_words_proc = subprocess.Popen( + [ + thirdparty_binary("lattice-align-words"), + word_boundary_int_path, + self.arguments.model_path, + "ark:-", + "ark:-", + ], + stdin=lin_proc.stdout, + stdout=subprocess.PIPE, + stderr=log_file, + env=os.environ, + ) + phone_proc = subprocess.Popen( + [ + thirdparty_binary("lattice-to-phone-lattice"), + self.arguments.model_path, + "ark:-", + "ark:-", + ], + stdout=subprocess.PIPE, + stdin=align_words_proc.stdout, + stderr=log_file, + env=os.environ, + ) + nbest_proc = subprocess.Popen( + [ + thirdparty_binary("nbest-to-ctm"), + "--print-args=false", + f"--frame-shift={self.arguments.frame_shift}", + "ark:-", + "-", + ], + stdin=phone_proc.stdout, + stderr=log_file, + stdout=subprocess.PIPE, + env=os.environ, + encoding="utf8", + ) + for line in nbest_proc.stdout: + line = line.strip() + if not line: + continue - except Exception: - self.stopped.stop() - exc_type, exc_value, exc_traceback = sys.exc_info() - self.error_catching[("phone", self.job_name)] = traceback.format_exception( - exc_type, exc_value, exc_traceback - ) - finally: - self.finished_signal.stop() + try: + interval = process_ctm_line(line) + except ValueError: + continue + if cur_utt is None: + cur_utt = interval.utterance + if cur_utt != interval.utterance: + + yield cur_utt, self.cleanup_intervals(intervals) + intervals = [] + cur_utt = interval.utterance + intervals.append(interval) + if intervals: + yield cur_utt, self.cleanup_intervals(intervals) class ExportTextGridProcessWorker(mp.Process): diff --git a/montreal_forced_aligner/alignment/pretrained.py b/montreal_forced_aligner/alignment/pretrained.py index b64b70e5..b3ddde12 100644 --- a/montreal_forced_aligner/alignment/pretrained.py +++ b/montreal_forced_aligner/alignment/pretrained.py @@ -1,14 +1,12 @@ """Class definitions for aligning with pretrained acoustic models""" from __future__ import annotations -import multiprocessing as mp import os import subprocess import time from collections import Counter, defaultdict from typing import TYPE_CHECKING, Dict, List, NamedTuple, Optional -import tqdm import yaml from montreal_forced_aligner.abc import TopLevelMfaWorker @@ -16,7 +14,6 @@ from montreal_forced_aligner.exceptions import KaldiProcessingError from montreal_forced_aligner.helper import align_phones, parse_old_features from montreal_forced_aligner.models import AcousticModel -from montreal_forced_aligner.textgrid import parse_aligned_textgrid from montreal_forced_aligner.utils import log_kaldi_errors, run_mp, run_non_mp, thirdparty_binary if TYPE_CHECKING: @@ -143,8 +140,9 @@ def __init__( **kwargs, ): self.acoustic_model = AcousticModel(acoustic_model_path) - kwargs.update(self.acoustic_model.parameters) - super().__init__(**kwargs) + kw = self.acoustic_model.parameters + kw.update(kwargs) + super().__init__(**kw) @property def working_directory(self) -> str: @@ -251,7 +249,6 @@ def workflow_identifier(self) -> str: def evaluate( self, - reference_directory: str, mapping: Optional[Dict[str, str]] = None, output_directory: Optional[str] = None, ) -> None: @@ -268,60 +265,9 @@ def evaluate( Directory to save results, if not specified, it will be saved in the log directory """ # Set up - per_utterance_phone_intervals = {} self.log_info("Evaluating alignments...") self.log_debug(f"Mapping: {mapping}") - indices = [] - jobs = [] - self.log_info("Loading reference alignments...") - for root, _, files in os.walk(reference_directory, followlinks=True): - root_speaker = os.path.basename(root) - for f in files: - if f.endswith(".TextGrid"): - file_name = f.replace(".TextGrid", "") - if file_name not in self.files: - continue - if self.use_mp: - indices.append(file_name) - jobs.append((os.path.join(root, f), root_speaker)) - else: - file = self.files[file_name] - intervals = parse_aligned_textgrid(os.path.join(root, f), root_speaker) - for u in file.utterances: - if u.begin is None or u.end is None: - for v in intervals.values(): - per_utterance_phone_intervals[u.name] = v - else: - if u.speaker_name not in intervals: - continue - utterance_name = u.name - if utterance_name not in per_utterance_phone_intervals: - per_utterance_phone_intervals[utterance_name] = [] - for interval in intervals[u.speaker_name]: - if interval.begin >= u.begin and interval.end <= u.end: - per_utterance_phone_intervals[utterance_name].append( - interval - ) - if self.use_mp: - with mp.Pool(self.num_jobs) as pool, tqdm.tqdm(total=len(jobs)) as pbar: - gen = pool.starmap(parse_aligned_textgrid, jobs) - for i, intervals in enumerate(gen): - pbar.update(1) - file_name = indices[i] - file = self.files[file_name] - for u in file.utterances: - if u.begin is None or u.end is None: - for v in intervals.values(): - per_utterance_phone_intervals[u.name] = v - else: - if u.speaker_name not in intervals: - continue - utterance_name = u.name - if utterance_name not in per_utterance_phone_intervals: - per_utterance_phone_intervals[utterance_name] = [] - for interval in intervals[u.speaker_name]: - if interval.begin >= u.begin and interval.end <= u.end: - per_utterance_phone_intervals[utterance_name].append(interval) + score_count = 0 score_sum = 0 phone_edit_sum = 0 @@ -331,25 +277,43 @@ def evaluate( else: csv_path = os.path.join(self.working_log_directory, "alignment_evaluation.csv") with open(csv_path, "w", encoding="utf8") as f: - f.write("utterance,score,phone_error_rate\n") - for utterance_name, intervals in per_utterance_phone_intervals.items(): - if not intervals: + f.write( + "utterance,file,speaker,duration,word_count,oov_count,reference_phone_count,score,phone_error_rate\n" + ) + for utterance in self.utterances: + if not utterance.reference_phone_labels: continue - if not self.utterances[utterance_name].phone_labels: + speaker = utterance.speaker_name + file = utterance.file_name + duration = utterance.duration + reference_phone_count = len(utterance.reference_phone_labels) + word_count = len(utterance.text.split()) + oov_count = len(utterance.oovs) + if not utterance.phone_labels: # couldn't be aligned + utterance.alignment_score = None + utterance.phone_error_rate = len(utterance.reference_phone_labels) + f.write( + f"{utterance.name},{file},{speaker},{duration},{word_count},{oov_count},{reference_phone_count},na,{len(utterance.reference_phone_labels)}\n" + ) + continue score, phone_error_rate = align_phones( - intervals, - self.utterances[utterance_name].phone_labels, + utterance.reference_phone_labels, + utterance.phone_labels, self.optional_silence_phone, mapping, ) if score is None: continue - f.write(f"{utterance_name},{score},{phone_error_rate}\n") + utterance.alignment_score = score + utterance.phone_error_rate = phone_error_rate + f.write( + f"{utterance.name},{file},{speaker},{duration},{word_count},{oov_count},{reference_phone_count},{score},{phone_error_rate}\n" + ) score_count += 1 score_sum += score - phone_edit_sum += int(phone_error_rate * len(intervals)) - phone_length_sum += len(intervals) + phone_edit_sum += int(phone_error_rate * reference_phone_count) + phone_length_sum += reference_phone_count self.logger.info(f"Average overlap score: {score_sum/score_count}") self.logger.info(f"Average phone error rate: {phone_edit_sum/phone_length_sum}") @@ -378,7 +342,7 @@ def align(self) -> None: self.align_utterances() self.compile_information() - + self.collect_alignments() except Exception as e: with open(dirty_path, "w"): pass diff --git a/montreal_forced_aligner/command_line/align.py b/montreal_forced_aligner/command_line/align.py index ec94e3ae..6f4e5db8 100644 --- a/montreal_forced_aligner/command_line/align.py +++ b/montreal_forced_aligner/command_line/align.py @@ -31,7 +31,6 @@ def align_corpus(args: Namespace, unknown_args: Optional[List[str]] = None) -> N aligner = PretrainedAligner( acoustic_model_path=args.acoustic_model_path, corpus_directory=args.corpus_directory, - audio_directory=args.audio_directory, dictionary_path=args.dictionary_path, temporary_directory=args.temporary_directory, **PretrainedAligner.parse_parameters(args.config_path, args, unknown_args), @@ -44,9 +43,8 @@ def align_corpus(args: Namespace, unknown_args: Optional[List[str]] = None) -> N if getattr(args, "custom_mapping_path", ""): with open(args.custom_mapping_path, "r", encoding="utf8") as f: mapping = yaml.safe_load(f) - aligner.evaluate( - args.reference_directory, mapping, output_directory=args.output_directory - ) + aligner.load_reference_alignments(args.reference_directory) + aligner.evaluate(mapping, output_directory=args.output_directory) except Exception: aligner.dirty = True raise diff --git a/montreal_forced_aligner/command_line/transcribe.py b/montreal_forced_aligner/command_line/transcribe.py index 1ea87e2e..643ca7fb 100644 --- a/montreal_forced_aligner/command_line/transcribe.py +++ b/montreal_forced_aligner/command_line/transcribe.py @@ -30,7 +30,6 @@ def transcribe_corpus(args: Namespace, unknown_args: Optional[List[str]] = None) acoustic_model_path=args.acoustic_model_path, language_model_path=args.language_model_path, corpus_directory=args.corpus_directory, - audio_directory=args.audio_directory, dictionary_path=args.dictionary_path, temporary_directory=args.temporary_directory, **Transcriber.parse_parameters(args.config_path, args, unknown_args), diff --git a/montreal_forced_aligner/command_line/validate.py b/montreal_forced_aligner/command_line/validate.py index b22c44eb..b2dee937 100644 --- a/montreal_forced_aligner/command_line/validate.py +++ b/montreal_forced_aligner/command_line/validate.py @@ -30,7 +30,6 @@ def validate_corpus(args: Namespace, unknown_args: Optional[List[str]] = None) - validator = PretrainedValidator( acoustic_model_path=args.acoustic_model_path, corpus_directory=args.corpus_directory, - audio_directory=args.audio_directory, dictionary_path=args.dictionary_path, temporary_directory=args.temporary_directory, **PretrainedValidator.parse_parameters(args.config_path, args, unknown_args), @@ -38,7 +37,6 @@ def validate_corpus(args: Namespace, unknown_args: Optional[List[str]] = None) - else: validator = TrainingValidator( corpus_directory=args.corpus_directory, - audio_directory=args.audio_directory, dictionary_path=args.dictionary_path, temporary_directory=args.temporary_directory, **TrainingValidator.parse_parameters(args.config_path, args, unknown_args), diff --git a/montreal_forced_aligner/corpus/acoustic_corpus.py b/montreal_forced_aligner/corpus/acoustic_corpus.py index 87f80b8f..0e47139a 100644 --- a/montreal_forced_aligner/corpus/acoustic_corpus.py +++ b/montreal_forced_aligner/corpus/acoustic_corpus.py @@ -27,9 +27,11 @@ ) from montreal_forced_aligner.corpus.helper import find_exts from montreal_forced_aligner.corpus.multiprocessing import CorpusProcessWorker +from montreal_forced_aligner.data import TextFileType from montreal_forced_aligner.dictionary.multispeaker import MultispeakerDictionaryMixin from montreal_forced_aligner.exceptions import TextGridParseError, TextParseError from montreal_forced_aligner.helper import load_scp +from montreal_forced_aligner.textgrid import parse_aligned_textgrid from montreal_forced_aligner.utils import KaldiProcessWorker, Stopped, thirdparty_binary __all__ = [ @@ -78,6 +80,77 @@ def __init__(self, audio_directory: Optional[str] = None, **kwargs): self.features_generated = False self.alignment_done = False self.transcription_done = False + self.has_reference_alignments = False + self.alignment_evaluation_done = False + + def _initialize_from_json(self, data): + self.features_generated = data.get("features_generated", False) + self.alignment_done = data.get("alignment_done", False) + self.transcription_done = data.get("transcription_done", False) + self.has_reference_alignments = data.get("has_reference_alignments", False) + self.alignment_evaluation_done = data.get("alignment_evaluation_done", False) + + @property + def corpus_meta(self): + return { + "features_generated": self.features_generated, + "alignment_done": self.alignment_done, + "transcription_done": self.transcription_done, + "has_reference_alignments": self.has_reference_alignments, + "alignment_evaluation_done": self.alignment_evaluation_done, + } + + def load_reference_alignments(self, reference_directory: str): + self.log_info("Loading reference files...") + indices = [] + jobs = [] + with tqdm.tqdm(total=len(self.files)) as pbar: + for root, _, files in os.walk(reference_directory, followlinks=True): + root_speaker = os.path.basename(root) + for f in files: + if f.endswith(".TextGrid"): + file_name = f.replace(".TextGrid", "") + if file_name not in self.files: + continue + if self.use_mp: + indices.append(file_name) + jobs.append((os.path.join(root, f), root_speaker)) + else: + file = self.files[file_name] + intervals = parse_aligned_textgrid(os.path.join(root, f), root_speaker) + for u in file.utterances: + if file.text_type == TextFileType.LAB: + for v in intervals.values(): + self.utterances[u.name].reference_phone_labels = v + else: + if u.speaker_name not in intervals: + continue + for interval in intervals[u.speaker_name]: + if interval.begin >= u.begin and interval.end <= u.end: + self.utterances[u.name].reference_phone_labels.append( + interval + ) + pbar.update(1) + if self.use_mp: + with mp.Pool(self.num_jobs) as pool: + gen = pool.starmap(parse_aligned_textgrid, jobs) + for i, intervals in enumerate(gen): + pbar.update(1) + file_name = indices[i] + file = self.files[file_name] + for u in file.utterances: + if file.text_type == TextFileType.LAB: + for v in intervals.values(): + self.utterances[u.name].reference_phone_labels = v + else: + if u.speaker_name not in intervals: + continue + for interval in intervals[u.speaker_name]: + if interval.begin >= u.begin and interval.end <= u.end: + self.utterances[u.name].reference_phone_labels.append( + interval + ) + self.has_reference_alignments = True def load_corpus(self) -> None: """ @@ -954,12 +1027,12 @@ def identifier(self) -> str: @property def output_directory(self) -> str: """Root temporary directory to store corpus and dictionary files""" - return os.path.join(self.temporary_directory, self.identifier) + return self.temporary_directory @property def working_directory(self) -> str: """Working directory to save temporary corpus and dictionary files""" - return self.output_directory + return self.corpus_output_directory class AcousticCorpusWithPronunciations( diff --git a/montreal_forced_aligner/corpus/base.py b/montreal_forced_aligner/corpus/base.py index cb60e397..8774b9a6 100644 --- a/montreal_forced_aligner/corpus/base.py +++ b/montreal_forced_aligner/corpus/base.py @@ -1,6 +1,7 @@ """Class definitions for corpora""" from __future__ import annotations +import json import os import random import time @@ -21,7 +22,7 @@ UtteranceCollection, ) from montreal_forced_aligner.corpus.multiprocessing import Job -from montreal_forced_aligner.data import SoundFileInformation +from montreal_forced_aligner.data import CtmInterval, SoundFileInformation from montreal_forced_aligner.exceptions import CorpusError from montreal_forced_aligner.helper import jsonl_encoder, output_mapping from montreal_forced_aligner.utils import Stopped @@ -100,6 +101,13 @@ def __init__( self.jobs: List[Job] = [] super().__init__(**kwargs) + def _initialize_from_json(self, data): + pass + + @property + def corpus_meta(self): + return {} + @property def features_directory(self) -> str: """Feature directory of the corpus""" @@ -124,6 +132,7 @@ def write_corpus_information(self) -> None: self._write_files() self._write_utterances() self._write_spk2utt() + self._write_corpus_info() def _write_spk2utt(self): """Write spk2utt scp file for Kaldi""" @@ -137,6 +146,13 @@ def write_utt2spk(self): data = {u.name: u.speaker.name for u in self.utterances} output_mapping(data, os.path.join(self.corpus_output_directory, "utt2spk.scp")) + def _write_corpus_info(self): + """Write speaker information for speeding up future runs""" + with open( + os.path.join(self.corpus_output_directory, "corpus.json"), "w", encoding="utf8" + ) as f: + json.dump(self.corpus_meta, f) + def _write_speakers(self): """Write speaker information for speeding up future runs""" with open( @@ -311,11 +327,7 @@ def create_subset(self, subset: int) -> None: larger_subset = sorted(self.utterances) random.seed(1234) # make it deterministic sampling subset_utts = UtteranceCollection() - try: - subset_utts.update(random.sample(larger_subset, subset)) - except ValueError: - print(subset, larger_subset_num, len(larger_subset)) - raise + subset_utts.update(random.sample(larger_subset, subset)) log_dir = os.path.join(subset_directory, "log") os.makedirs(log_dir, exist_ok=True) @@ -412,6 +424,7 @@ def _load_corpus_from_temp(self) -> bool: ) self.num_jobs = old_num_jobs format = "jsonl" + corpus_path = os.path.join(self.corpus_output_directory, "corpus.json") speakers_path = os.path.join(self.corpus_output_directory, "speakers.jsonl") files_path = os.path.join(self.corpus_output_directory, "files.jsonl") utterances_path = os.path.join(self.corpus_output_directory, "utterances.jsonl") @@ -432,6 +445,11 @@ def _load_corpus_from_temp(self) -> bool: return False self.log_debug("Loading from temporary files...") + if os.path.exists(corpus_path): + with open(corpus_path, "r", encoding="utf8") as f: + data = json.load(f) + self._initialize_from_json(data) + with open(speakers_path, "r", encoding="utf8") as f: if format == "jsonl": speaker_data = jsonlines.Reader(f) @@ -479,16 +497,27 @@ def _load_corpus_from_temp(self) -> bool: ) u.oovs = set(entry["oovs"]) u.normalized_text = entry["normalized_text"] - self.utterances[u.name] = u if u.text: self.word_counts.update(u.text.split()) if u.normalized_text: self.word_counts.update(u.normalized_text) - if entry.get("word_error_rate", None) is not None: - u.word_error_rate = entry["word_error_rate"] - u.transcription_text = entry["transcription_text"] - self.utterances[u.name].features = entry["features"] - self.utterances[u.name].ignored = entry["ignored"] + u.word_error_rate = entry.get("word_error_rate", None) + u.transcription_text = entry.get("transcription_text", None) + u.phone_error_rate = entry.get("phone_error_rate", None) + u.alignment_score = entry.get("alignment_score", None) + u.alignment_log_likelihood = entry.get("alignment_log_likelihood", None) + u.reference_phone_labels = [ + CtmInterval(**x) for x in entry.get("reference_phone_labels", []) + ] + + phone_labels = entry.get("phone_labels", None) + if phone_labels: + u.phone_labels = [CtmInterval(**x) for x in phone_labels] + word_labels = entry.get("word_labels", None) + if word_labels: + u.word_labels = [CtmInterval(**x) for x in word_labels] + u.features = entry.get("features", None) + u.ignored = entry.get("ignored", False) self.add_utterance(u) self.log_debug( diff --git a/montreal_forced_aligner/corpus/classes.py b/montreal_forced_aligner/corpus/classes.py index d4c75e8f..d18f1ea9 100644 --- a/montreal_forced_aligner/corpus/classes.py +++ b/montreal_forced_aligner/corpus/classes.py @@ -393,6 +393,7 @@ def save( output_directory: Optional[str] = None, backup_output_directory: Optional[str] = None, text_type: Optional[TextFileType] = None, + save_transcription: bool = False, ) -> None: """ Output File to TextGrid or lab. If ``text_type`` is not specified, the original file type will be used, @@ -408,6 +409,8 @@ def save( instead use this directory text_type: TextFileType, optional Text type to save as, if not provided, it will use either the original file type or guess the file type + save_transcription: bool + Flag for whether the hypothesized transcription text should be saved instead of the default text """ utterance_count = len(self.utterances) if text_type is None: @@ -418,18 +421,20 @@ def save( else: text_type = TextFileType.TEXTGRID if text_type == TextFileType.LAB: - if utterance_count == 0 and os.path.exists(self.text_path): + if utterance_count == 0 and os.path.exists(self.text_path) and not save_transcription: os.remove(self.text_path) return - utterance = next(iter(self.utterances)) + elif utterance_count == 0: + return output_path = self.construct_output_path( output_directory, backup_output_directory, enforce_lab=True ) with open(output_path, "w", encoding="utf8") as f: - if utterance.transcription_text is not None: - f.write(utterance.transcription_text) - else: - f.write(utterance.text) + for u in self.utterances: + if save_transcription: + f.write(u.transcription_text if u.transcription_text else "") + else: + f.write(u.text) return elif text_type == TextFileType.TEXTGRID: output_path = self.construct_output_path(output_directory, backup_output_directory) @@ -451,12 +456,14 @@ def save( speaker = utterance.speaker if not self.aligned: - if utterance.transcription_text is not None: + if save_transcription: tiers[speaker].entryList.append( Interval( start=utterance.begin, end=utterance.end, - label=utterance.transcription_text, + label=utterance.transcription_text + if utterance.transcription_text + else "", ) ) else: @@ -762,13 +769,10 @@ def normalized_waveform( np.abs(self.waveform[:, begin_sample:end_sample]), axis=0 ) y[np.isnan(y)] = 0 - y[0, :] += 3 - y[0, :] += 1 else: - y = ( - self.waveform[begin_sample:end_sample] - / np.max(np.abs(self.waveform[begin_sample:end_sample]), axis=0) - ) + 1 + y = self.waveform[begin_sample:end_sample] / np.max( + np.abs(self.waveform[begin_sample:end_sample]), axis=0 + ) x = np.arange(start=begin_sample, stop=end_sample) / self.sample_rate return x, y @@ -854,10 +858,14 @@ def __init__( self.features = None self.phone_labels: Optional[List[CtmInterval]] = None self.word_labels: Optional[List[CtmInterval]] = None + self.reference_phone_labels: Optional[List[CtmInterval]] = [] self.oovs = set() self.normalized_text = [] self.text_int = [] + self.alignment_log_likelihood = None self.word_error_rate = None + self.phone_error_rate = None + self.alignment_score = None def parse_transcription(self, sanitize_function=Optional[MultispeakerSanitizationFunction]): """ @@ -879,13 +887,15 @@ def parse_transcription(self, sanitize_function=Optional[MultispeakerSanitizatio words = [ sanitize(w) for w in self.text.split() - if w not in sanitize.clitic_markers + sanitize.compound_markers + if w and w not in sanitize.clitic_markers + sanitize.compound_markers ] self.text = " ".join(words) if split is not None: for w in words: for new_w in split(w): - if new_w not in split.word_set: + if not new_w: + continue + if split.word_set is not None and new_w not in split.word_set: self.oovs.add(new_w) self.normalized_text.append(new_w) @@ -975,7 +985,13 @@ def meta(self) -> Dict[str, Any]: "normalized_text": self.normalized_text, "oovs": self.oovs, "transcription_text": self.transcription_text, + "reference_phone_labels": self.reference_phone_labels, + "phone_labels": self.phone_labels, + "word_labels": self.word_labels, "word_error_rate": self.word_error_rate, + "phone_error_rate": self.phone_error_rate, + "alignment_score": self.alignment_score, + "alignment_log_likelihood": self.alignment_log_likelihood, } def set_speaker(self, speaker: Speaker) -> None: @@ -1012,7 +1028,7 @@ def add_word_intervals(self, intervals: Union[CtmInterval, List[CtmInterval]]) - for interval in intervals: if self.begin is not None: interval.shift_times(self.begin) - self.word_labels.extend(intervals) + self.word_labels = intervals def add_phone_intervals(self, intervals: Union[CtmInterval, List[CtmInterval]]) -> None: """ @@ -1030,7 +1046,7 @@ def add_phone_intervals(self, intervals: Union[CtmInterval, List[CtmInterval]]) for interval in intervals: if self.begin is not None: interval.shift_times(self.begin) - self.phone_labels.extend(intervals) + self.phone_labels = intervals def text_for_scp(self) -> List[str]: """ diff --git a/montreal_forced_aligner/corpus/text_corpus.py b/montreal_forced_aligner/corpus/text_corpus.py index 4b5f6d68..69f41e15 100644 --- a/montreal_forced_aligner/corpus/text_corpus.py +++ b/montreal_forced_aligner/corpus/text_corpus.py @@ -33,9 +33,10 @@ def _load_corpus_from_source_mp(self) -> None: """ if self.stopped is None: self.stopped = Stopped() - sanitize_function = None - if hasattr(self, "construct_sanitize_function"): - sanitize_function = self.construct_sanitize_function() + try: + sanitize_function = self.sanitize_function + except AttributeError: + sanitize_function = None begin_time = time.time() manager = mp.Manager() job_queue = manager.Queue() @@ -69,7 +70,6 @@ def _load_corpus_from_source_mp(self) -> None: if self.stopped.stop_check(): break wav_path = None - transcription_path = None if file_name in exts.lab_files: lab_name = exts.lab_files[file_name] transcription_path = os.path.join(root, lab_name) @@ -77,6 +77,8 @@ def _load_corpus_from_source_mp(self) -> None: elif file_name in exts.textgrid_files: tg_name = exts.textgrid_files[file_name] transcription_path = os.path.join(root, tg_name) + else: + continue job_queue.put((file_name, wav_path, transcription_path, relative_path)) finished_adding.stop() @@ -160,9 +162,10 @@ def _load_corpus_from_source(self) -> None: begin_time = time.time() self.stopped = False - sanitize_function = None - if hasattr(self, "construct_sanitize_function"): - sanitize_function = self.construct_sanitize_function() + try: + sanitize_function = self.sanitize_function + except AttributeError: + sanitize_function = None for root, _, files in os.walk(self.corpus_directory, followlinks=True): exts = find_exts(files) relative_path = root.replace(self.corpus_directory, "").lstrip("/").lstrip("\\") @@ -171,14 +174,14 @@ def _load_corpus_from_source(self) -> None: for file_name in exts.identifiers: wav_path = None - transcription_path = None if file_name in exts.lab_files: lab_name = exts.lab_files[file_name] transcription_path = os.path.join(root, lab_name) elif file_name in exts.textgrid_files: tg_name = exts.textgrid_files[file_name] transcription_path = os.path.join(root, tg_name) - + else: + continue try: file = File.parse_file( file_name, diff --git a/montreal_forced_aligner/data.py b/montreal_forced_aligner/data.py index 3e0bfda6..69f17a95 100644 --- a/montreal_forced_aligner/data.py +++ b/montreal_forced_aligner/data.py @@ -377,7 +377,6 @@ def extra_questions(self) -> typing.Dict[str, typing.Set[str]]: extra_questions["dental_lenition"] = voiced_variants("ð") | voiced_variants("d") extra_questions["flapping"] = {"d", "t", "ɾ"} extra_questions["glottalization"] = {"t", "ʔ", "t̚"} - extra_questions["glottal_variation"] = self.vowels | {"ʔ"} extra_questions["labial_lenition"] = voiced_variants("β") | voiced_variants("b") extra_questions["velar_lenition"] = voiced_variants("ɣ") | voiced_variants("ɡ") @@ -843,6 +842,9 @@ class CtmInterval: label: str utterance: str + def __lt__(self, other: CtmInterval): + return self.begin < other.begin + def __post_init__(self): """ Check on data validity diff --git a/montreal_forced_aligner/dictionary/mixins.py b/montreal_forced_aligner/dictionary/mixins.py index 21fcbfdf..c0df0a9a 100644 --- a/montreal_forced_aligner/dictionary/mixins.py +++ b/montreal_forced_aligner/dictionary/mixins.py @@ -180,7 +180,7 @@ def split_clitics( split.append(clitic) except IndexError: pass - if not any(x in self.word_set for x in split): + if self.word_set is not None and not any(x in self.word_set for x in split): return [item] return split diff --git a/montreal_forced_aligner/dictionary/multispeaker.py b/montreal_forced_aligner/dictionary/multispeaker.py index b624ee6e..6c9a6eb2 100644 --- a/montreal_forced_aligner/dictionary/multispeaker.py +++ b/montreal_forced_aligner/dictionary/multispeaker.py @@ -69,8 +69,12 @@ def get_functions_for_speaker( :class:`~montreal_forced_aligner.dictionary.mixins.SplitWordsFunction` Function for splitting up words """ - dict_name = self.get_dict_name_for_speaker(speaker_name) - return self.sanitize_function, self.split_functions[dict_name] + try: + dict_name = self.get_dict_name_for_speaker(speaker_name) + split_function = self.split_functions[dict_name] + except KeyError: + split_function = None + return self.sanitize_function, split_function class MultispeakerDictionaryMixin(TemporaryDictionaryMixin, metaclass=abc.ABCMeta): @@ -122,7 +126,7 @@ def sanitize_function(self) -> MultispeakerSanitizationFunction: self.clitic_markers, self.compound_markers, dictionary.clitic_set, - set(dictionary.words.keys()), + set(dictionary.actual_words.keys()), ) return MultispeakerSanitizationFunction( self.speaker_mapping, sanitize_function, split_functions @@ -238,7 +242,7 @@ def write_lexicon_information(self, write_disambiguation: Optional[bool] = False self._write_topo() self._write_extra_questions() for d in self.dictionary_mapping.values(): - d.write(write_disambiguation) + d.write(write_disambiguation, debug=getattr(self, "debug", False)) def set_lexicon_word_set(self, word_set: Collection[str]) -> None: """ diff --git a/montreal_forced_aligner/dictionary/pronunciation.py b/montreal_forced_aligner/dictionary/pronunciation.py index 31a2f454..a99bbae3 100644 --- a/montreal_forced_aligner/dictionary/pronunciation.py +++ b/montreal_forced_aligner/dictionary/pronunciation.py @@ -381,7 +381,7 @@ def to_int(self, item: str, normalized=False) -> List[int]: if item == "": return [] if normalized: - if item in self.words_mapping: + if item in self.words_mapping and item not in self.specials_set: return [self.words_mapping[item]] else: return [self.oov_int] @@ -636,10 +636,10 @@ def _write_fst_binary( if write_disambiguation: temp2_fst_path = os.path.join(self.dictionary_output_directory, "temp2.fst") word_disambig_path = os.path.join( - self.dictionary_output_directory, "word_disambig0.txt" + self.dictionary_output_directory, "word_disambig.txt" ) phone_disambig_path = os.path.join( - self.dictionary_output_directory, "phone_disambig0.txt" + self.dictionary_output_directory, "phone_disambig.txt" ) with open(phone_disambig_path, "w") as f: f.write(str(self.phone_mapping["#0"])) diff --git a/montreal_forced_aligner/helper.py b/montreal_forced_aligner/helper.py index 1e71af06..812ffe66 100644 --- a/montreal_forced_aligner/helper.py +++ b/montreal_forced_aligner/helper.py @@ -5,6 +5,7 @@ """ from __future__ import annotations +import dataclasses import functools import itertools import json @@ -16,7 +17,7 @@ from colorama import Fore, Style if TYPE_CHECKING: - from montreal_forced_aligner.abc import CorpusMappingType, Labels, MetaDict, ScpType + from montreal_forced_aligner.abc import CorpusMappingType, MetaDict, ScpType from montreal_forced_aligner.dictionary.pronunciation import Word from montreal_forced_aligner.textgrid import CtmInterval @@ -618,7 +619,7 @@ def load_scp(path: str, data_type: Optional[Type] = str) -> CorpusMappingType: return scp -def edit_distance(x: Labels, y: Labels) -> int: +def edit_distance(x: List[str], y: List[str]) -> int: """ Compute edit distance between two sets of labels @@ -629,9 +630,9 @@ def edit_distance(x: Labels, y: Labels) -> int: Parameters ---------- - x: Labels + x: list[str] First sequence to compare - y: Labels + y: list[str] Second sequence to compare Returns @@ -692,15 +693,15 @@ def score_g2p(gold: Word, hypo: Word) -> Tuple[int, int]: return edits, best_length -def score(gold: Labels, hypo: Labels, multiple_hypotheses=False) -> Tuple[int, int]: +def score(gold: List[str], hypo: List[str]) -> Tuple[int, int]: """ Computes sufficient statistics for LER calculation. Parameters ---------- - gold: Labels + gold: list[str] The reference labels - hypo: Labels + hypo: list[str] The hypothesized labels multiple_hypotheses: bool Flag for whether the hypotheses contain multiple @@ -712,16 +713,7 @@ def score(gold: Labels, hypo: Labels, multiple_hypotheses=False) -> Tuple[int, i int Length of the gold labels """ - if multiple_hypotheses: - edits = 100000 - for h in hypo: - e = edit_distance(gold, h) - if e < edits: - edits = e - if not edits: - break - else: - edits = edit_distance(gold, hypo) + edits = edit_distance(gold, hypo) return edits, len(gold) @@ -800,15 +792,19 @@ def overlap_scoring( return -1 * (begin_diff + end_diff + label_diff) -def set_default(obj): +class EnhancedJSONEncoder(json.JSONEncoder): """JSON serialization""" - if isinstance(obj, set): - return list(obj) - raise TypeError + + def default(self, o): + if dataclasses.is_dataclass(o): + return dataclasses.asdict(o) + if isinstance(o, set): + return list(o) + return dataclasses.asdict(o) def jsonl_encoder(obj): - return json.dumps(obj, default=set_default) + return json.dumps(obj, cls=EnhancedJSONEncoder) def align_phones( diff --git a/montreal_forced_aligner/ivector/trainer.py b/montreal_forced_aligner/ivector/trainer.py index 87ece018..0576b6a0 100644 --- a/montreal_forced_aligner/ivector/trainer.py +++ b/montreal_forced_aligner/ivector/trainer.py @@ -39,6 +39,21 @@ class IvectorModelTrainingMixin(AcousticModelTrainingMixin): For acoustic model training parsing parameters """ + @property + def meta(self) -> MetaDict: + """Generate metadata for the acoustic model that was trained""" + from datetime import datetime + + from ..utils import get_mfa_version + + data = { + "version": get_mfa_version(), + "architecture": self.architecture, + "train_date": str(datetime.now()), + "features": self.feature_options, + } + return data + def export_model(self, output_model_path: str) -> None: """ Output IvectorExtractor model diff --git a/montreal_forced_aligner/language_modeling/trainer.py b/montreal_forced_aligner/language_modeling/trainer.py index 7995c88f..eddc399f 100644 --- a/montreal_forced_aligner/language_modeling/trainer.py +++ b/montreal_forced_aligner/language_modeling/trainer.py @@ -99,7 +99,7 @@ def prune_large_language_model(self) -> None: self.log_info("Pruning large ngram model to medium and small versions...") small_mod_path = self.mod_path.replace(".mod", "_small.mod") med_mod_path = self.mod_path.replace(".mod", "_med.mod") - subprocess.call( + subprocess.check_call( [ "ngramshrink", f"--method={self.prune_method}", @@ -108,10 +108,12 @@ def prune_large_language_model(self) -> None: med_mod_path, ] ) - subprocess.call(["ngramprint", "--ARPA", med_mod_path, self.medium_arpa_path]) + assert os.path.exists(med_mod_path) + subprocess.check_call(["ngramprint", "--ARPA", med_mod_path, self.medium_arpa_path]) + assert os.path.exists(self.medium_arpa_path) self.log_debug("Finished pruning medium arpa!") - subprocess.call( + subprocess.check_call( [ "ngramshrink", f"--method={self.prune_method}", @@ -120,7 +122,9 @@ def prune_large_language_model(self) -> None: small_mod_path, ] ) - subprocess.call(["ngramprint", "--ARPA", small_mod_path, self.small_arpa_path]) + assert os.path.exists(small_mod_path) + subprocess.check_call(["ngramprint", "--ARPA", small_mod_path, self.small_arpa_path]) + assert os.path.exists(self.small_arpa_path) self.log_debug("Finished pruning small arpa!") self.log_info("Done pruning!") @@ -188,7 +192,7 @@ def setup(self) -> None: self.save_oovs_found(self.working_directory) subprocess.call( - ["ngramsymbols", f'--OOV_symbol="{self.oov_word}"', self.training_path, self.sym_path] + ["ngramsymbols", f"--OOV_symbol={self.oov_word}", self.training_path, self.sym_path] ) self.initialized = True @@ -241,7 +245,7 @@ def evaluate(self) -> None: perplexity_proc = subprocess.Popen( [ "ngramperplexity", - f'--OOV_symbol="{self.oov_word}"', + f"--OOV_symbol={self.oov_word}", self.mod_path, self.far_path, ], @@ -274,7 +278,7 @@ def evaluate(self) -> None: perplexity_proc = subprocess.Popen( [ "ngramperplexity", - f'--OOV_symbol="{self.oov_word}"', + f"--OOV_symbol={self.oov_word}", med_mod_path, self.far_path, ], @@ -293,7 +297,7 @@ def evaluate(self) -> None: perplexity_proc = subprocess.Popen( [ "ngramperplexity", - f'--OOV_symbol="{self.oov_word}"', + f"--OOV_symbol={self.oov_word}", small_mod_path, self.far_path, ], @@ -326,49 +330,41 @@ def normalized_text_iter(self, min_count: int = 1) -> Generator: """ unk_words = {k for k, v in self.word_counts.items() if v <= min_count} for u in self.utterances: - text = u.text.split() - new_text = [] - for t in text: - if u.speaker.dictionary is not None: - u.speaker.dictionary.to_int(t) - lookup = u.speaker.dictionary.split_clitics(t) - if lookup is None: - continue - else: - lookup = [t] - for item in lookup: - if item in unk_words: - new_text.append(self.oov_word) - self.oovs_found[item] += 1 - elif ( - u.speaker.dictionary is not None and item not in u.speaker.dictionary.words - ): - new_text.append(self.oov_word) - else: - new_text.append(item) - yield " ".join(new_text) + normalized = u.normalized_text + if normalized: + normalized = u.text.split() + yield " ".join(x if x not in unk_words else self.oov_word for x in normalized) def train(self) -> None: """ Train a language model """ self.log_info("Beginning training large ngram model...") - subprocess.call( + subprocess.check_call( [ "farcompilestrings", "--fst_type=compact", - f'--unknown_symbol="{self.oov_word}"', + f"--unknown_symbol={self.oov_word}", f"--symbols={self.sym_path}", "--keep_symbols", self.training_path, self.far_path, ] ) - subprocess.call(["ngramcount", f"--order={self.order}", self.far_path, self.cnts_path]) - subprocess.call(["ngrammake", f"--method={self.method}", self.cnts_path, self.mod_path]) + assert os.path.exists(self.far_path) + subprocess.check_call( + ["ngramcount", f"--order={self.order}", self.far_path, self.cnts_path] + ) + + assert os.path.exists(self.cnts_path) + subprocess.check_call( + ["ngrammake", f"--method={self.method}", self.cnts_path, self.mod_path] + ) + assert os.path.exists(self.mod_path) self.log_info("Done!") - subprocess.call(["ngramprint", "--ARPA", self.mod_path, self.large_arpa_path]) + subprocess.check_call(["ngramprint", "--ARPA", self.mod_path, self.large_arpa_path]) + assert os.path.exists(self.large_arpa_path) self.log_info("Large ngam model created!") @@ -388,7 +384,28 @@ class LmDictionaryCorpusTrainer(MultispeakerDictionaryMixin, LmCorpusTrainer): For dictionary parsing parameters """ - pass + def setup(self) -> None: + """Set up language model training""" + if self.initialized: + return + os.makedirs(self.working_log_directory, exist_ok=True) + self.dictionary_setup() + self._load_corpus() + self.set_lexicon_word_set(self.corpus_word_set) + self.write_lexicon_information() + + with open(self.training_path, "w", encoding="utf8") as f: + for text in self.normalized_text_iter(self.count_threshold): + f.write(f"{text}\n") + + self.save_oovs_found(self.working_directory) + + self.initialized = True + + @property + def sym_path(self): + """Internal path to symbols file""" + return os.path.join(self.default_dictionary.dictionary_output_directory, "words.txt") class LmArpaTrainer(LmTrainerMixin, TopLevelMfaWorker): @@ -443,9 +460,10 @@ def train(self) -> None: with open( os.path.join(self.working_log_directory, "read.log"), "w", encoding="utf8" ) as log_file: - subprocess.call( + subprocess.check_call( ["ngramread", "--ARPA", self.large_arpa_path, self.mod_path], stderr=log_file ) + assert os.path.exists(self.mod_path) self.log_info("Large ngam model parsed!") diff --git a/montreal_forced_aligner/models.py b/montreal_forced_aligner/models.py index dd3e0be2..3c288826 100644 --- a/montreal_forced_aligner/models.py +++ b/montreal_forced_aligner/models.py @@ -21,7 +21,7 @@ ModelLoadError, PronunciationAcousticMismatchError, ) -from montreal_forced_aligner.helper import TerminalPrinter, set_default +from montreal_forced_aligner.helper import EnhancedJSONEncoder, TerminalPrinter if TYPE_CHECKING: from logging import Logger @@ -198,9 +198,16 @@ def meta(self) -> dict: Get the meta data associated with the model """ if not self._meta: - meta_path = os.path.join(self.dirname, "meta.yaml") + meta_path = os.path.join(self.dirname, "meta.json") + format = "json" + if not os.path.exists(meta_path): + meta_path = os.path.join(self.dirname, "meta.yaml") + format = "yaml" with open(meta_path, "r", encoding="utf8") as f: - self._meta = yaml.safe_load(f) + if format == "yaml": + self._meta = yaml.safe_load(f) + else: + self._meta = json.load(f) self.parse_old_features() return self._meta @@ -213,8 +220,8 @@ def add_meta_file(self, trainer: ModelExporterMixin) -> None: trainer: :class:`~montreal_forced_aligner.abc.ModelExporterMixin` The trainer to construct the metadata from """ - with open(os.path.join(self.dirname, "meta.yaml"), "w", encoding="utf8") as f: - yaml.dump(trainer.meta, f) + with open(os.path.join(self.dirname, "meta.json"), "w", encoding="utf8") as f: + json.dump(trainer.meta, f) @classmethod def empty( @@ -310,8 +317,8 @@ def add_meta_file(self, trainer: ModelExporterMixin) -> None: trainer: :class:`~montreal_forced_aligner.abc.ModelExporterMixin` Trainer to supply metadata information about the acoustic model """ - with open(os.path.join(self.dirname, "meta.yaml"), "w", encoding="utf8") as f: - yaml.dump(trainer.meta, f) + with open(os.path.join(self.dirname, "meta.json"), "w", encoding="utf8") as f: + json.dump(trainer.meta, f) @property def parameters(self) -> MetaDict: @@ -351,7 +358,11 @@ def meta(self) -> MetaDict: "splice_right_context": 3, } if not self._meta: - meta_path = os.path.join(self.dirname, "meta.yaml") + meta_path = os.path.join(self.dirname, "meta.json") + format = "json" + if not os.path.exists(meta_path): + meta_path = os.path.join(self.dirname, "meta.yaml") + format = "yaml" if not os.path.exists(meta_path): self._meta = { "version": "0.9.0", @@ -360,7 +371,10 @@ def meta(self) -> MetaDict: } else: with open(meta_path, "r", encoding="utf8") as f: - self._meta = yaml.safe_load(f) + if format == "yaml": + self._meta = yaml.safe_load(f) + else: + self._meta = json.load(f) if self._meta["features"] == "mfcc+deltas": self._meta["features"] = default_features if "phone_type" not in self._meta: @@ -468,7 +482,9 @@ def log_details(self, logger: Logger) -> None: logger.debug("====ACOUSTIC MODEL INFO====") logger.debug("Acoustic model root directory: " + self.root_directory) logger.debug("Acoustic model dirname: " + self.dirname) - meta_path = os.path.join(self.dirname, "meta.yaml") + meta_path = os.path.join(self.dirname, "meta.json") + if not os.path.exists(meta_path): + meta_path = os.path.join(self.dirname, "meta.yaml") logger.debug("Acoustic model meta path: " + meta_path) if not os.path.exists(meta_path): logger.debug("META.YAML DOES NOT EXIST, this may cause issues in validating the model") @@ -581,7 +597,7 @@ def add_meta_file(self, g2p_trainer: G2PTrainer) -> None: """ with open(os.path.join(self.dirname, "meta.json"), "w", encoding="utf8") as f: - json.dump(g2p_trainer.meta, f, default=set_default) + json.dump(g2p_trainer.meta, f, cls=EnhancedJSONEncoder) @property def meta(self) -> dict: diff --git a/montreal_forced_aligner/transcription/transcriber.py b/montreal_forced_aligner/transcription/transcriber.py index d63af426..6de63172 100644 --- a/montreal_forced_aligner/transcription/transcriber.py +++ b/montreal_forced_aligner/transcription/transcriber.py @@ -5,6 +5,7 @@ """ from __future__ import annotations +import csv import itertools import multiprocessing as mp import os @@ -1330,52 +1331,69 @@ def evaluate(self): self._load_transcripts() # Sentence-level measures - correct = 0 incorrect = 0 + total_count = 0 # Word-level measures total_edits = 0 total_length = 0 issues = {} indices = [] + to_comp = [] + for utterance in self.utterances: + utt_name = utterance.name + if not utterance.text: + continue + total_count += 1 + g = utterance.text.split() + total_length += len(g) + if not utterance.transcription_text: + incorrect += 1 + total_edits += len(g) + issues[utt_name] = [g, "", 1] + continue + + h = utterance.transcription_text.split() + if g != h: + issues[utt_name] = [g, h] + indices.append(utt_name) + to_comp.append((g, h)) + incorrect += 1 + else: + issues[utt_name] = [g, h, 0] with mp.Pool(self.num_jobs) as pool: - to_comp = [] - for utterance in self.utterances: - utt_name = utterance.name - if not utterance.text: - continue - g = utterance.text.split() - if not utterance.transcription_text: - incorrect += 1 - total_edits += len(g) - total_length += len(g) - issues[utt_name] = [g, "", 1] - continue - - h = utterance.transcription_text.split() - if g != h: - issues[utt_name] = [g, h] - indices.append(utt_name) - to_comp.append((g, h)) - else: - issues[utt_name] = [g, h, 0] gen = pool.starmap(score, to_comp) for i, (edits, length) in enumerate(gen): issues[indices[i]].append(edits / length) - if edits == 0: - correct += 1 - else: - incorrect += 1 total_edits += edits - total_length += length output_path = os.path.join(self.evaluation_directory, "transcription_evaluation.csv") - with open(output_path, "w", encoding="utf8") as f: - f.write("utterance,gold_transcript,hypothesis,WER\n") - for utt, (g, h, wer) in issues.items(): - self.utterances[utt].word_error_rate = wer + with open(output_path, "w", newline="", encoding="utf8") as f: + writer = csv.writer(f) + writer.writerow( + [ + "utterance", + "file", + "speaker", + "duration", + "word_count", + "oov_count", + "gold_transcript", + "hypothesis", + "WER", + ] + ) + for utt in sorted(issues.keys()): + g, h, wer = issues[utt] + utterance = self.utterances[utt] + utterance.word_error_rate = wer + speaker = utterance.speaker_name + file = utterance.file_name + duration = utterance.duration + word_count = len(utterance.text.split()) + oov_count = len(utterance.oovs) g = " ".join(g) h = " ".join(h) - f.write(f"{utt},{g},{h},{wer}\n") - ser = 100 * incorrect / (correct + incorrect) + writer.writerow([utt, file, speaker, duration, word_count, oov_count, g, h, wer]) + ser = 100 * incorrect / total_count wer = 100 * total_edits / total_length self.logger.info(f"SER: {ser:.2f}%, WER: {wer:.2f}%") return ser, wer @@ -1415,7 +1433,9 @@ def export_files(self, output_directory: str) -> None: os.makedirs(backup_output_directory, exist_ok=True) self._load_transcripts() for file in self.files: - file.save(output_directory, backup_output_directory) + if len(file.utterances) == 0: + self.logger.debug(f"Could not find any utterances for {file.name}") + file.save(output_directory, backup_output_directory, save_transcription=True) if self.evaluation_mode: shutil.copyfile( os.path.join(self.evaluation_directory, "transcription_evaluation.csv"), diff --git a/tests/conftest.py b/tests/conftest.py index 063f2b25..f4aad4d6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -211,23 +211,43 @@ def basic_split_dir(corpus_root_dir, wav_dir, lab_dir, textgrid_dir): audio_path = os.path.join(path, "audio") text_path = os.path.join(path, "text") os.makedirs(path, exist_ok=True) - names = [("michael", ["acoustic_corpus"]), ("sickmichael", ["cold_corpus", "cold_corpus3"])] + names = [ + ("michael", ["acoustic_corpus"]), + ("sickmichael", ["cold_corpus", "cold_corpus3"]), + ( + "speaker", + [ + "multilingual_ipa", + "multilingual_ipa_2", + "multilingual_ipa_3", + "multilingual_ipa_4", + "multilingual_ipa_5", + ], + ), + ( + "speaker_two", + [ + "multilingual_ipa_us", + "multilingual_ipa_us_2", + "multilingual_ipa_us_3", + "multilingual_ipa_us_4", + "multilingual_ipa_us_5", + ], + ), + ] for s, files in names: s_text_dir = os.path.join(text_path, s) s_audio_dir = os.path.join(audio_path, s) os.makedirs(s_text_dir, exist_ok=True) os.makedirs(s_audio_dir, exist_ok=True) for name in files: - shutil.copyfile( - os.path.join(wav_dir, name + ".wav"), os.path.join(s_audio_dir, name + ".wav") - ) - shutil.copyfile( - os.path.join(lab_dir, name + ".lab"), os.path.join(s_text_dir, name + ".lab") - ) - shutil.copyfile( - os.path.join(textgrid_dir, "acoustic_corpus.TextGrid"), - os.path.join(s_text_dir, "acoustic_corpus_nonsense.TextGrid"), - ) + wav_path = os.path.join(wav_dir, name + ".wav") + if os.path.exists(wav_path): + shutil.copyfile(wav_path, wav_path.replace(wav_dir, s_audio_dir)) + lab_path = os.path.join(lab_dir, name + ".lab") + if not os.path.exists(lab_path): + lab_path = os.path.join(lab_dir, name + ".txt") + shutil.copyfile(lab_path, lab_path.replace(lab_dir, s_text_dir)) return audio_path, text_path diff --git a/tests/data/lab/weird_words.lab b/tests/data/lab/weird_words.lab index 620999af..279da01f 100644 --- a/tests/data/lab/weird_words.lab +++ b/tests/data/lab/weird_words.lab @@ -1 +1 @@ -i’m talking-ajfish me-really asds-asda sdasd-me +i’m talking-ajfish me-really asds-asda sdasd-me diff --git a/tests/test_commandline_lm.py b/tests/test_commandline_lm.py index 6915018e..79cc28e5 100644 --- a/tests/test_commandline_lm.py +++ b/tests/test_commandline_lm.py @@ -49,7 +49,7 @@ def test_train_lm_text(basic_split_dir, temp_dir, generated_dir, basic_train_lm_ def test_train_lm_dictionary( - basic_split_dir, basic_dict_path, temp_dir, generated_dir, basic_train_lm_config_path + basic_split_dir, sick_dict_path, temp_dir, generated_dir, basic_train_lm_config_path ): if sys.platform == "win32": pytest.skip("LM training not supported on Windows.") @@ -62,7 +62,7 @@ def test_train_lm_dictionary( "-t", temp_dir, "--dictionary_path", - basic_dict_path, + sick_dict_path, "--config_path", basic_train_lm_config_path, "-q", diff --git a/tests/test_corpus.py b/tests/test_corpus.py index d67439bb..50de0852 100644 --- a/tests/test_corpus.py +++ b/tests/test_corpus.py @@ -409,8 +409,14 @@ def test_weird_words(weird_words_dir, generated_dir, sick_dict_path): "ajfish", "asds-asda", "sdasd", + "", + "", } - + print(corpus.utterances["weird-words-weird-words-0-26-72325"].text_int_for_scp()) + assert ( + corpus.utterances["weird-words-weird-words-0-26-72325"].text_int_for_scp()[-1] + == corpus.default_dictionary.oov_int + ) corpus.set_lexicon_word_set(corpus.corpus_word_set) for w in ["i'm", "this'm", "sdsdsds'm", "'m"]: _ = corpus.default_dictionary.to_int(w) @@ -501,13 +507,15 @@ def test_no_punctuation(punctuated_dir, generated_dir, sick_dict_path, no_punctu "mean...", ] weird_words = corpus.utterances["punctuated-weird-words-0-26-72325"] - assert weird_words.text == "i’m talking-ajfish me-really asds-asda sdasd-me" + assert weird_words.text == "i’m talking-ajfish me-really asds-asda sdasd-me " assert weird_words.normalized_text == [ "i’m", "talking-ajfish", "me-really", "asds-asda", "sdasd-me", + "", + "", ]