diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 649b9476..23f3789f 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -20,7 +20,7 @@ jobs: - uses: "actions/checkout@v2" - uses: "actions/setup-python@v2" with: - python-version: "3.8" + python-version: "3.9" - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.gitignore b/.gitignore index 3b0108c2..672c58a0 100644 --- a/.gitignore +++ b/.gitignore @@ -11,8 +11,7 @@ report.txt .DS_Store -tests/data/generated -docs/source/generated +generated/ pretrained_models/ @@ -83,5 +82,5 @@ docs/source/api/ montreal_forced_aligner/_version.py /docs/source/reference/generated/ docs/source/reference/multiprocessing/generated/ diff --git a/docs/source/_static/interrogate_badge.svg b/docs/source/_static/interrogate_badge.svg index 2e1a4ff3..58711083 100644 --- a/docs/source/_static/interrogate_badge.svg +++ b/docs/source/_static/interrogate_badge.svg @@ -1,5 +1,5 @@ - interrogate: 96.8% + interrogate: 98.8% @@ -12,8 +12,8 @@ interrogate interrogate - 96.8% - 96.8% + 98.8% + 98.8% diff --git a/docs/source/_templates/class_b.rst b/docs/source/_templates/class_b.rst deleted file mode 100644 index 9b12af20..00000000 --- a/docs/source/_templates/class_b.rst +++ /dev/null @@ -1,34 +0,0 @@ -{{ objname }} -{{ underline }} - -.. currentmodule:: {{ module }} - -.. autoclass:: {{ objname }} - - {% block attributes %} - - {% if attributes %} - .. rubric:: Attributes - - .. autosummary:: - {% for item in attributes %} - {% if item != '__init__' %} - ~{{ name }}.{{ item }} - {% endif %} - {% endfor %} - {% endif %} - {% endblock %} - - {% block methods %} - - {% if methods %} - .. rubric:: Methods - - .. 
autosummary:: - {% for item in methods %} - {% if item != '__init__' %} - ~{{ name }}.{{ item }} - {% endif %} - {% endfor %} - {% endif %} - {% endblock %} diff --git a/docs/source/_templates/function_b.rst b/docs/source/_templates/function_b.rst deleted file mode 100644 index 32c17bc1..00000000 --- a/docs/source/_templates/function_b.rst +++ /dev/null @@ -1,10 +0,0 @@ -{{objname}} -{{ underline }} - -.. currentmodule:: {{ module }} - -.. autofunction:: {{ objname }} - -.. raw:: html - -
diff --git a/docs/source/changelog/changelog_2.0.rst b/docs/source/changelog/changelog_2.0.rst index f843cbba..171a6dd3 100644 --- a/docs/source/changelog/changelog_2.0.rst +++ b/docs/source/changelog/changelog_2.0.rst @@ -10,6 +10,18 @@ Beta releases ============= +2.0.0b8 +------- + +- Refactored internal organization to rely on mixins more than monolithic classes, and moved internal functions to be organized by what they're used for instead of the general type. + + - For instance, there used to be a ``montreal_forced_aligner.multiprocessing`` module with ``alignment.py``, ``transcription.py``, etc that all did multiprocessing for various workers. Now that functionality is located closer to where it's used, i.e. ``montreal_forced_aligner.transcription.multiprocessing``. + - Mixins should allow for more easy extension to new use cases and allow for better configuration + +- Updated documentation to reflect the refactoring and did a pass over the User Guide +- Added the ability to change the location of root MFA directory based on the ``MFA_ROOT_DIR`` environment variable +- Fixed an issue where the version was incorrectly reported as "2.0.0" + 2.0.0b5 ------- @@ -23,8 +35,8 @@ Beta releases - Massive refactor to a proper class-based API for interacting with MFA corpora - Sorry, I really do hope this is the last big refactor of 2.0 - - :class:`~montreal_forced_aligner.corpus.Speaker`, :class:`~montreal_forced_aligner.corpus.File`, and :class:`~montreal_forced_aligner.corpus.Utterance` have dedicated classes rather than having their information split across dictionaries mimicking Kaldi files, so they should be more useful for interacting with outside of MFA - - Added :class:`~montreal_forced_aligner.multiprocessing.Job` class as well to make it easier to generate and keep track of information about different processes + - :class:`~montreal_forced_aligner.corpus.classes.Speaker`, :class:`~montreal_forced_aligner.corpus.classes.File`, and 
:class:`~montreal_forced_aligner.corpus.classes.Utterance` have dedicated classes rather than having their information split across dictionaries mimicking Kaldi files, so they should be more useful for interacting with outside of MFA + - Added :class:`~montreal_forced_aligner.corpus.multiprocessing.Job` class as well to make it easier to generate and keep track of information about different processes - Updated installation style to be more dependent on conda-forge packages - Kaldi and MFA are now on conda-forge! |:tada:| diff --git a/docs/source/conf.py b/docs/source/conf.py index 956650da..280ecaa8 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -61,7 +61,10 @@ xref_links = { "mfa_mailing_list": ("MFA mailing list", "https://groups.google.com/g/mfa-users"), - "mfa_github": ("MFA GitHub Repo", "https://groups.google.com/g/mfa-users"), + "mfa_github": ( + "MFA GitHub Repo", + "https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner", + ), "mfa_github_issues": ( "MFA GitHub Issues", "https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner/issues", @@ -79,6 +82,9 @@ "kaldi_github": ("Kaldi GitHub", "https://github.com/kaldi-asr/kaldi"), "htk": ("HTK", "http://htk.eng.cam.ac.uk/"), "phonetisaurus": ("Phonetisaurus", "https://github.com/AdolfVonKleist/Phonetisaurus"), + "opengrm_ngram": ("OpenGrm-NGram", "https://www.openfst.org/twiki/bin/view/GRM/NGramLibrary"), + "openfst": ("OpenFst", "https://www.openfst.org/twiki/bin/view/FST"), + "baumwelch": ("Baum-Welch", "https://www.opengrm.org/twiki/bin/view/GRM/BaumWelch"), "pynini": ("Pynini", "https://www.openfst.org/twiki/bin/view/GRM/Pynini"), "prosodylab_aligner": ("Prosodylab-aligner", "http://prosodylab.org/tools/aligner/"), "p2fa": ( @@ -126,19 +132,22 @@ "Trainer": "montreal_forced_aligner.abc.Trainer", "Aligner": "montreal_forced_aligner.abc.Aligner", "DictionaryData": "montreal_forced_aligner.dictionary.DictionaryData", - "Utterance": "montreal_forced_aligner.corpus.Utterance", - 
"File": "montreal_forced_aligner.corpus.File", + "Utterance": "montreal_forced_aligner.corpus.classes.Utterance", + "File": "montreal_forced_aligner.corpus.classes.File", "FeatureConfig": "montreal_forced_aligner.config.FeatureConfig", "multiprocessing.context.Process": "multiprocessing.Process", "mp.Process": "multiprocessing.Process", - "Speaker": "montreal_forced_aligner.corpus.Speaker", + "Speaker": "montreal_forced_aligner.corpus.classes.Speaker", + "Namespace": "argparse.Namespace", + "MetaDict": "dict[str, Any]", } napoleon_preprocess_types = False napoleon_attr_annotations = False napoleon_use_param = True +napoleon_use_ivar = True napoleon_type_aliases = { - "Labels": "List[str]", + "Labels": "list[str]", } typehints_fully_qualified = False # numpydoc_xref_param_type = True @@ -222,13 +231,13 @@ nitpick_ignore = [ ("py:class", "optional"), ("py:class", "callable"), - ("py:class", "CtmType"), ("py:class", "ReversedMappingType"), ("py:class", "WordsType"), ("py:class", "MappingType"), ("py:class", "TextIO"), ("py:class", "SegmentationType"), ("py:class", "CtmErrorDict"), + ("py:class", "kwargs"), ("py:class", "Labels"), ("py:class", "ScpType"), ("py:class", "multiprocessing.Value"), diff --git a/docs/source/external_links.py b/docs/source/external_links.py index 4f3f621f..e7f56217 100644 --- a/docs/source/external_links.py +++ b/docs/source/external_links.py @@ -17,7 +17,7 @@ :license: BSD, see LICENSE for details. 
""" -from typing import Any, Dict, List, Tuple +from typing import Any import sphinx from docutils import nodes, utils @@ -41,9 +41,9 @@ def model_role( text: str, lineno: int, inliner: Inliner, - options: Dict = None, - content: List[str] = None, -) -> Tuple[List[Node], List[system_message]]: + options: dict = None, + content: list[str] = None, +) -> tuple[list[Node], list[system_message]]: text = utils.unescape(text) model_type, model_name = text.split("/") full_url = f"https://github.com/MontrealCorpusTools/mfa-models/raw/main/{model_type}/{model_name.lower()}.zip" @@ -58,9 +58,9 @@ def kaldi_steps_role( text: str, lineno: int, inliner: Inliner, - options: Dict = None, - content: List[str] = None, -) -> Tuple[List[Node], List[system_message]]: + options: dict = None, + content: list[str] = None, +) -> tuple[list[Node], list[system_message]]: text = utils.unescape(text) full_url = f"https://github.com/kaldi-asr/kaldi/tree/master/egs/wsj/s5/steps/{text}.sh" title = f"{text}.sh" @@ -74,9 +74,9 @@ def kaldi_utils_role( text: str, lineno: int, inliner: Inliner, - options: Dict = None, - content: List[str] = None, -) -> Tuple[List[Node], List[system_message]]: + options: dict = None, + content: list[str] = None, +) -> tuple[list[Node], list[system_message]]: filename = utils.unescape(text) full_url = f"https://github.com/kaldi-asr/kaldi/tree/master/egs/wsj/s5/utils/{filename}" title = f"{text}" @@ -90,9 +90,9 @@ def kaldi_steps_sid_role( text: str, lineno: int, inliner: Inliner, - options: Dict = None, - content: List[str] = None, -) -> Tuple[List[Node], List[system_message]]: + options: dict = None, + content: list[str] = None, +) -> tuple[list[Node], list[system_message]]: text = utils.unescape(text) full_url = f"https://github.com/kaldi-asr/kaldi/tree/cbed4ff688a172a7f765493d24771c1bd57dcd20/egs/sre08/v1/sid/{text}.sh" title = f"sid/{text}.sh" @@ -106,9 +106,9 @@ def kaldi_docs_role( text: str, lineno: int, inliner: Inliner, - options: Dict = None, - content: 
List[str] = None, -) -> Tuple[List[Node], List[system_message]]: + options: dict = None, + content: list[str] = None, +) -> tuple[list[Node], list[system_message]]: text = utils.unescape(text) t = text.split("#") text = t[0] @@ -129,9 +129,9 @@ def openfst_src_role( text: str, lineno: int, inliner: Inliner, - options: Dict = None, - content: List[str] = None, -) -> Tuple[List[Node], List[system_message]]: + options: dict = None, + content: list[str] = None, +) -> tuple[list[Node], list[system_message]]: text = utils.unescape(text) full_url = f"https://www.openfst.org/doxygen/fst/html/{text}-main_8cc_source.html" title = f"OpenFst {text} source" @@ -145,9 +145,9 @@ def kaldi_src_role( text: str, lineno: int, inliner: Inliner, - options: Dict = None, - content: List[str] = None, -) -> Tuple[List[Node], List[system_message]]: + options: dict = None, + content: list[str] = None, +) -> tuple[list[Node], list[system_message]]: text = utils.unescape(text) mapping = { "bin": set( @@ -378,9 +378,9 @@ def xref( text: str, lineno: int, inliner: Inliner, - options: Dict = None, - content: List[str] = None, -) -> Tuple[List[Node], List[system_message]]: + options: dict = None, + content: list[str] = None, +) -> tuple[list[Node], list[system_message]]: title = target = text # look if explicit title and target are given with `foo ` syntax @@ -409,7 +409,7 @@ def get_refs(app): xref.links = app.config.xref_links -def setup(app: Sphinx) -> Dict[str, Any]: +def setup(app: Sphinx) -> dict[str, Any]: app.add_config_value("xref_links", {}, "env") app.add_role("mfa_model", model_role) app.add_role("kaldi_steps", kaldi_steps_role) diff --git a/docs/source/first_steps/index.rst b/docs/source/first_steps/index.rst index e9953051..2f4fbf70 100644 --- a/docs/source/first_steps/index.rst +++ b/docs/source/first_steps/index.rst @@ -17,7 +17,7 @@ There are several broad use cases that you might want to use MFA for. Take a lo #. 
**Use case 1:** You have a speech corpus, the language involved is in the list of :ref:`pretrained_acoustic_models` and the list of :ref:`pretrained_dictionaries`. - #. Follow :ref:`first_steps_align_pretrained` to generate aligned TextGrids + #. Follow :ref:`first_steps_align_pretrained` to generate aligned TextGrids #. **Use case 2:** You have a speech corpus, the language involved is in the list of :ref:`pretrained_acoustic_models` and the list of :ref:`pretrained_g2p`, but not on the list of :ref:`pretrained_dictionaries`. diff --git a/docs/source/installation.rst b/docs/source/installation.rst index bfe2ea86..30c2649b 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -34,6 +34,26 @@ In general, it's recommend to create a new environment. If you want to update, Windows native install is not fully supported in 2.0. G2P functionality will be unavailable due to Pynini supporting only Linux and MacOS. To use G2P functionality on Windows, please set up the :xref:`wsl` and use the Bash console to continue the instructions. +Installing from source +====================== + +If the Conda installation above does not work or the binaries don't work on your system, you can try building Kaldi and OpenFst from source, along with MFA. + +1. Download/clone the :xref:`kaldi_github` and follow the installation instructions +2. If you're on Mac or Linux and want G2P functionality, install :xref:`openfst`, :xref:`opengrm_ngram`, :xref:`baumwelch`, and :xref:`pynini` +3. Make sure all Kaldi and other third party executables are on the system path +4. Download/clone the :xref:`mfa_github` and install MFA via :code:`python setup.py install` or :code:`pip install -e .` +5. Double check everything's working on the console with :code:`mfa -h` + +.. note:: + + You can also clone the conda forge feedstocks for `OpenFst `_, `SoX `_, `Kaldi `_, and `MFA `_ and run them with `conda build `_ to build for your specific system. 
+ +MFA temporary files +=================== + +MFA uses a temporary directory for commands that can be specified in running commands with ``--temp_directory`` (or see :ref:`configuration`), and it also uses a directory to store global configuration settings and saved models. By default this root directory is ``~/Documents/MFA``, but if you would like to put this somewhere else, you can set the environment variable ``MFA_ROOT_DIR`` to use that. MFA will raise an error on load if it's unable to write the specified root directory. + Supported functionality ======================= diff --git a/docs/source/reference/abc.rst b/docs/source/reference/abc.rst deleted file mode 100644 index 8a279bd3..00000000 --- a/docs/source/reference/abc.rst +++ /dev/null @@ -1,13 +0,0 @@ -.. automodule:: montreal_forced_aligner.abc - - .. autosummary:: - :toctree: generated/ - - MfaModel -- Base model type for MFA - MfaWorker -- Base worker class for MFA - AcousticModelWorker -- MFA workers that have acoustic models - Aligner -- Aligner type interface - Dictionary -- Dictionary type interface - IvectorExtractor -- Ivector extractor type interface - Trainer -- Trainer type interface - Transcriber -- Transcriber type interface diff --git a/docs/source/reference/acoustic_modeling/helper.rst b/docs/source/reference/acoustic_modeling/helper.rst new file mode 100644 index 00000000..b731b961 --- /dev/null +++ b/docs/source/reference/acoustic_modeling/helper.rst @@ -0,0 +1,104 @@ + +Helper functionality +==================== + +Mixins +------ + +.. currentmodule:: montreal_forced_aligner.acoustic_modeling.base + +.. autosummary:: + :toctree: generated/ + + AcousticModelTrainingMixin -- Basic mixin + + +Multiprocessing workers and functions +------------------------------------- + +.. currentmodule:: montreal_forced_aligner.acoustic_modeling.base + +.. autosummary:: + :toctree: generated/ + + acc_stats_func + compute_alignment_improvement_func + compare_alignments + + +.. 
currentmodule:: montreal_forced_aligner.acoustic_modeling.monophone + +.. autosummary:: + :toctree: generated/ + + mono_align_equal_func + + +.. currentmodule:: montreal_forced_aligner.acoustic_modeling.triphone + +.. autosummary:: + :toctree: generated/ + + tree_stats_func + convert_alignments_func + + +.. currentmodule:: montreal_forced_aligner.acoustic_modeling.lda + +.. autosummary:: + :toctree: generated/ + + lda_acc_stats_func + calc_lda_mllt_func + + +.. currentmodule:: montreal_forced_aligner.acoustic_modeling.sat + +.. autosummary:: + :toctree: generated/ + + acc_stats_two_feats_func + +Multiprocessing argument classes +-------------------------------- + +.. currentmodule:: montreal_forced_aligner.acoustic_modeling.base + +.. autosummary:: + :toctree: generated/ + + AccStatsArguments + AlignmentImprovementArguments + + +.. currentmodule:: montreal_forced_aligner.acoustic_modeling.monophone + +.. autosummary:: + :toctree: generated/ + + MonoAlignEqualArguments + + +.. currentmodule:: montreal_forced_aligner.acoustic_modeling.triphone + +.. autosummary:: + :toctree: generated/ + + TreeStatsArguments + ConvertAlignmentsArguments + +.. currentmodule:: montreal_forced_aligner.acoustic_modeling.lda + +.. autosummary:: + :toctree: generated/ + + LdaAccStatsArguments + CalcLdaMlltArguments + + +.. currentmodule:: montreal_forced_aligner.acoustic_modeling.sat + +.. autosummary:: + :toctree: generated/ + + AccStatsTwoFeatsArguments diff --git a/docs/source/reference/acoustic_modeling/index.rst b/docs/source/reference/acoustic_modeling/index.rst new file mode 100644 index 00000000..c06580dd --- /dev/null +++ b/docs/source/reference/acoustic_modeling/index.rst @@ -0,0 +1,24 @@ + +.. _acoustic_modeling_api: + +Acoustic models +=============== + +:term:`Acoustic models` contain information about how phones are pronounced, trained over large (and not-so-large) corpora of speech. 
Currently only GMM-HMM style acoustic models are supported, which are generally good enough for alignment, but nowhere near state of the art for transcription. + +.. note:: + + As part of the training procedure, alignments are generated, and so can be exported at the end (the same as training an acoustic model and then using it with the :class:`~montreal_forced_aligner.alignment.pretrained.PretrainedAligner`. See :meth:`~montreal_forced_aligner.alignment.CorpusAligner.export_files` for the method and :func:`~montreal_forced_aligner.command_line.run_train_acoustic_model` for the command line function. + +.. currentmodule:: montreal_forced_aligner.models + +.. autosummary:: + :toctree: generated/ + + AcousticModel + +.. toctree:: + :hidden: + + training + helper diff --git a/docs/source/reference/acoustic_modeling/training.rst b/docs/source/reference/acoustic_modeling/training.rst new file mode 100644 index 00000000..3f8e4893 --- /dev/null +++ b/docs/source/reference/acoustic_modeling/training.rst @@ -0,0 +1,22 @@ + +.. _acoustic_model_training_api: + +Training acoustic models +======================== + +.. currentmodule:: montreal_forced_aligner.acoustic_modeling.trainer + +.. autosummary:: + :toctree: generated/ + + TrainableAligner + +.. currentmodule:: montreal_forced_aligner.acoustic_modeling + +.. autosummary:: + :toctree: generated/ + + MonophoneTrainer -- Monophone trainer + TriphoneTrainer -- Triphone trainer + LdaTrainer -- LDA trainer + SatTrainer -- Speaker adapted trainer diff --git a/docs/source/reference/aligner.rst b/docs/source/reference/aligner.rst deleted file mode 100644 index 0334b59f..00000000 --- a/docs/source/reference/aligner.rst +++ /dev/null @@ -1,9 +0,0 @@ -.. automodule:: montreal_forced_aligner.aligner - - .. 
autosummary:: - :toctree: generated/ - - BaseAligner -- Base aligner - AdaptingAligner -- Adapting aligner - PretrainedAligner -- Pretrained aligner - TrainableAligner -- Trainable aligner diff --git a/docs/source/reference/alignment/alignment.rst b/docs/source/reference/alignment/alignment.rst new file mode 100644 index 00000000..a0d9038d --- /dev/null +++ b/docs/source/reference/alignment/alignment.rst @@ -0,0 +1,14 @@ + +.. _aligners_api: + +Alignment classes +================= + +.. currentmodule:: montreal_forced_aligner.alignment + +.. autosummary:: + :toctree: generated/ + + CorpusAligner -- Base aligner + AdaptingAligner -- Adapting an acoustic model to new data + PretrainedAligner -- Pretrained aligner diff --git a/docs/source/reference/alignment/helper.rst b/docs/source/reference/alignment/helper.rst new file mode 100644 index 00000000..a2b8cdce --- /dev/null +++ b/docs/source/reference/alignment/helper.rst @@ -0,0 +1,74 @@ + +Helper functionality +==================== + +Mixins +------ + +.. currentmodule:: montreal_forced_aligner.alignment.mixins + +.. autosummary:: + :toctree: generated/ + + AlignMixin -- Alignment mixin + +Multiprocessing workers and functions +------------------------------------- + +.. currentmodule:: montreal_forced_aligner.alignment.adapting + +.. autosummary:: + :toctree: generated/ + + map_acc_stats_func + +.. currentmodule:: montreal_forced_aligner.alignment.multiprocessing + +.. autosummary:: + :toctree: generated/ + + align_func + compile_train_graphs_func + compile_information_func + ali_to_ctm_func + PhoneCtmProcessWorker + CleanupWordCtmProcessWorker + NoCleanupWordCtmProcessWorker + CombineProcessWorker + ExportPreparationProcessWorker + ExportTextGridProcessWorker + + +Multiprocessing argument classes +-------------------------------- + +.. currentmodule:: montreal_forced_aligner.alignment.adapting + +.. autosummary:: + :toctree: generated/ + + MapAccStatsArguments + +.. 
currentmodule:: montreal_forced_aligner.alignment.multiprocessing + +.. autosummary:: + :toctree: generated/ + + AlignArguments + compile_train_graphs_func + CompileTrainGraphsArguments + compile_information_func + CompileInformationArguments + ali_to_ctm_func + AliToCtmArguments + PhoneCtmProcessWorker + PhoneCtmArguments + CleanupWordCtmProcessWorker + CleanupWordCtmArguments + NoCleanupWordCtmProcessWorker + NoCleanupWordCtmArguments + CombineProcessWorker + CombineCtmArguments + ExportPreparationProcessWorker + ExportTextGridProcessWorker + ExportTextGridArguments diff --git a/docs/source/reference/alignment/index.rst b/docs/source/reference/alignment/index.rst new file mode 100644 index 00000000..df72d628 --- /dev/null +++ b/docs/source/reference/alignment/index.rst @@ -0,0 +1,10 @@ + +.. _alignment_api: + +Alignment +========= + +.. toctree:: + + alignment + helper diff --git a/docs/source/reference/base_index.rst b/docs/source/reference/base_index.rst deleted file mode 100644 index 1df4819d..00000000 --- a/docs/source/reference/base_index.rst +++ /dev/null @@ -1,9 +0,0 @@ - -Base classes -============ - -.. toctree:: - - corpus - dictionary - models diff --git a/docs/source/reference/command_line.rst b/docs/source/reference/command_line.rst deleted file mode 100644 index 07ab83df..00000000 --- a/docs/source/reference/command_line.rst +++ /dev/null @@ -1,29 +0,0 @@ -Command line functions -====================== - -.. automodule:: montreal_forced_aligner.command_line - - .. 
autosummary:: - :toctree: generated/ - - main - create_parser - validate_model_arg - run_transcribe_corpus - run_validate_corpus - run_train_lm - run_train_g2p - run_align_corpus - run_train_dictionary - run_anchor - run_adapt_model - run_train_acoustic_model - run_train_ivector_extractor - run_g2p - run_create_segments - run_classify_speakers - run_model - list_model - save_model - inspect_model - download_model diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst deleted file mode 100644 index c5f1f9ad..00000000 --- a/docs/source/reference/config.rst +++ /dev/null @@ -1,17 +0,0 @@ -.. automodule:: montreal_forced_aligner.config - - .. autosummary:: - :toctree: generated/ - - BaseConfig -- Base configuration - AlignConfig -- Alignment configuration - DictionaryConfig -- Dictionary configuration - CommandConfig -- Command configuration - FeatureConfig -- Feature configuration - SegmentationConfig -- Segmentation configuration - SpeakerClassificationConfig -- Speaker classification configuration - TrainingConfig -- Training configuration - TrainLMConfig -- Training language model configuration - TranscribeConfig -- Transcription configuration - TrainG2PConfig -- Train G2P model configuration - G2PConfig -- G2P configuration diff --git a/docs/source/reference/core_index.rst b/docs/source/reference/core_index.rst new file mode 100644 index 00000000..d5e787af --- /dev/null +++ b/docs/source/reference/core_index.rst @@ -0,0 +1,15 @@ + +Core functionality +================== + +This sections contains the core objects that are used as input to any top level worker: the corpora, pronunciation dictionaries, and various types of MFA models. Each model's section contains the classes and functionality used to train them. + +.. 
toctree:: + :maxdepth: 1 + + corpus/index + dictionary/index + acoustic_modeling/index + g2p_modeling/index + language_modeling/index + ivector/index diff --git a/docs/source/reference/corpus.rst b/docs/source/reference/corpus.rst deleted file mode 100644 index 20fee2ff..00000000 --- a/docs/source/reference/corpus.rst +++ /dev/null @@ -1,10 +0,0 @@ - -.. automodule:: montreal_forced_aligner.corpus - - .. autosummary:: - :toctree: generated/ - - Corpus -- Class for defining corpora in MFA - Speaker -- Class for collecting metadata about speakers in corpora - File -- Class for representing sound file/transcription file pairs in corpora - Utterance -- Class for collecting information about utterances diff --git a/docs/source/reference/corpus/index.rst b/docs/source/reference/corpus/index.rst new file mode 100644 index 00000000..bad9ed3a --- /dev/null +++ b/docs/source/reference/corpus/index.rst @@ -0,0 +1,105 @@ + +.. _corpus_api: + +Corpora +======= + +.. currentmodule:: montreal_forced_aligner.corpus.acoustic_corpus + +.. autosummary:: + :toctree: generated/ + + AcousticCorpus + +.. currentmodule:: montreal_forced_aligner.corpus.text_corpus + +.. autosummary:: + :toctree: generated/ + + TextCorpus + +.. currentmodule:: montreal_forced_aligner.corpus.classes + +.. autosummary:: + :toctree: generated/ + + Speaker -- Class for collecting metadata about speakers in corpora + File -- Class for representing sound file/transcription file pairs in corpora + Utterance -- Class for collecting information about utterances + +Helper classes and functions +============================ + +Multiprocessing +--------------- + +.. currentmodule:: montreal_forced_aligner.corpus.multiprocessing + +.. autosummary:: + :toctree: generated/ + + Job + CorpusProcessWorker + +Mixins +------ + +.. currentmodule:: montreal_forced_aligner.corpus.base + +.. autosummary:: + :toctree: generated/ + + CorpusMixin + +.. currentmodule:: montreal_forced_aligner.corpus.acoustic_corpus + +.. 
autosummary:: + :toctree: generated/ + + AcousticCorpusMixin + AcousticCorpusPronunciationMixin + +.. currentmodule:: montreal_forced_aligner.corpus.ivector_corpus + +.. autosummary:: + :toctree: generated/ + + IvectorCorpusMixin + +.. currentmodule:: montreal_forced_aligner.corpus.text_corpus + +.. autosummary:: + :toctree: generated/ + + TextCorpusMixin + DictionaryTextCorpusMixin + +Features +-------- + +.. currentmodule:: montreal_forced_aligner.corpus.features + +.. autosummary:: + :toctree: generated/ + + FeatureConfigMixin + mfcc_func + MfccArguments + calc_fmllr_func + CalcFmllrArguments + IvectorConfigMixin + VadConfigMixin + compute_vad_func + VadArguments + VadArguments + +Ivector +------- + +.. currentmodule:: montreal_forced_aligner.corpus.features + +.. autosummary:: + :toctree: generated/ + + extract_ivectors_func + ExtractIvectorsArguments diff --git a/docs/source/reference/dictionary.rst b/docs/source/reference/dictionary.rst deleted file mode 100644 index f6f5e3e7..00000000 --- a/docs/source/reference/dictionary.rst +++ /dev/null @@ -1,9 +0,0 @@ - -.. automodule:: montreal_forced_aligner.dictionary - - .. autosummary:: - :toctree: generated/ - - PronunciationDictionary -- Pronunciation dictionary for Kaldi - MultispeakerDictionary -- Collection of pronunciation dictionaries that specify speaker-dictionary mappings - DictionaryData -- Data class generated by PronunciationDictionary to parse to and from Kaldi-internal strings diff --git a/docs/source/reference/dictionary/helper.rst b/docs/source/reference/dictionary/helper.rst new file mode 100644 index 00000000..fb74c49a --- /dev/null +++ b/docs/source/reference/dictionary/helper.rst @@ -0,0 +1,68 @@ + +Helper classes and functions +============================ + +Model +----- + +.. currentmodule:: montreal_forced_aligner.models + +.. autosummary:: + :toctree: generated/ + + DictionaryModel + +Mixins +------ + +.. currentmodule:: montreal_forced_aligner.dictionary.mixins + +.. 
autosummary:: + :toctree: generated/ + + DictionaryMixin + +.. currentmodule:: montreal_forced_aligner.dictionary.base + +.. autosummary:: + :toctree: generated/ + + PronunciationDictionaryMixin + +.. currentmodule:: montreal_forced_aligner.dictionary.multispeaker + +.. autosummary:: + :toctree: generated/ + + MultispeakerDictionaryMixin + +Helper +------ + +.. currentmodule:: montreal_forced_aligner.dictionary + +.. autosummary:: + :toctree: generated/ + + DictionaryData -- Data class generated by PronunciationDictionary to parse to and from Kaldi-internal strings + +.. currentmodule:: montreal_forced_aligner.dictionary.mixins + +.. autosummary:: + :toctree: generated/ + + SanitizeFunction + +Pronunciation probability functionality +======================================= + +Helper +------ + +.. currentmodule:: montreal_forced_aligner.alignment.pretrained + +.. autosummary:: + :toctree: generated/ + + generate_pronunciations_func + GeneratePronunciationsArguments diff --git a/docs/source/reference/dictionary/index.rst b/docs/source/reference/dictionary/index.rst new file mode 100644 index 00000000..d4c872dc --- /dev/null +++ b/docs/source/reference/dictionary/index.rst @@ -0,0 +1,11 @@ + +.. _dictionary_training_api: + +Pronunciation dictionaries +========================== + +.. toctree:: + + main + helper + training diff --git a/docs/source/reference/dictionary/main.rst b/docs/source/reference/dictionary/main.rst new file mode 100644 index 00000000..8eb93c41 --- /dev/null +++ b/docs/source/reference/dictionary/main.rst @@ -0,0 +1,11 @@ + +Main classes +============ + +.. currentmodule:: montreal_forced_aligner.dictionary + +.. 
autosummary:: + :toctree: generated/ + + PronunciationDictionary -- Pronunciation dictionary for Kaldi + MultispeakerDictionary -- Collection of pronunciation dictionaries that specify speaker-dictionary mappings diff --git a/docs/source/reference/dictionary/training.rst b/docs/source/reference/dictionary/training.rst new file mode 100644 index 00000000..7d76b1c4 --- /dev/null +++ b/docs/source/reference/dictionary/training.rst @@ -0,0 +1,10 @@ + +Training pronunciation probabilities +==================================== + +.. currentmodule:: montreal_forced_aligner.alignment.pretrained + +.. autosummary:: + :toctree: generated/ + + DictionaryTrainer -- Train pronunciation probabilities from alignments diff --git a/docs/source/reference/g2p.rst b/docs/source/reference/g2p.rst deleted file mode 100644 index ddaf274a..00000000 --- a/docs/source/reference/g2p.rst +++ /dev/null @@ -1,7 +0,0 @@ -.. automodule:: montreal_forced_aligner.g2p - - .. autosummary:: - :toctree: generated/ - - PyniniTrainer -- Trainer for Pynini G2P model - PyniniDictionaryGenerator -- Generator for Pynini G2P model diff --git a/docs/source/reference/g2p/generator.rst b/docs/source/reference/g2p/generator.rst new file mode 100644 index 00000000..1b5b71de --- /dev/null +++ b/docs/source/reference/g2p/generator.rst @@ -0,0 +1,13 @@ + +.. _generating_dictionaries_api: + +Dictionary generation +===================== + +.. currentmodule:: montreal_forced_aligner.g2p.generator + +.. autosummary:: + :toctree: generated/ + + PyniniCorpusGenerator -- Generator for Pynini G2P model + PyniniWordListGenerator -- Generator for Pynini G2P model diff --git a/docs/source/reference/g2p/helper.rst b/docs/source/reference/g2p/helper.rst new file mode 100644 index 00000000..0731307c --- /dev/null +++ b/docs/source/reference/g2p/helper.rst @@ -0,0 +1,24 @@ + +Helper functionality +==================== + +Mixins +------ + +.. currentmodule:: montreal_forced_aligner.g2p.generator + +.. 
autosummary:: + :toctree: generated/ + + PyniniGenerator + +Helper +------ + +.. currentmodule:: montreal_forced_aligner.g2p.generator + +.. autosummary:: + :toctree: generated/ + + Rewriter + RewriterWorker diff --git a/docs/source/reference/g2p/index.rst b/docs/source/reference/g2p/index.rst new file mode 100644 index 00000000..d37573a4 --- /dev/null +++ b/docs/source/reference/g2p/index.rst @@ -0,0 +1,10 @@ + +.. _g2p_generate_api: + +Generating dictionaries +======================= + +.. toctree:: + + generator + helper diff --git a/docs/source/reference/g2p_modeling/helper.rst b/docs/source/reference/g2p_modeling/helper.rst new file mode 100644 index 00000000..52740250 --- /dev/null +++ b/docs/source/reference/g2p_modeling/helper.rst @@ -0,0 +1,34 @@ + +Helper functionality +==================== + + +Mixins +------ + +.. currentmodule:: montreal_forced_aligner.g2p.mixins + +.. autosummary:: + :toctree: generated/ + + G2PMixin + G2PTopLevelMixin + +.. currentmodule:: montreal_forced_aligner.g2p.trainer + +.. autosummary:: + :toctree: generated/ + + G2PTrainer + +Helper +------ + +.. currentmodule:: montreal_forced_aligner.g2p.trainer + +.. autosummary:: + :toctree: generated/ + + PairNGramAligner + RandomStartWorker + RandomStart diff --git a/docs/source/reference/g2p_modeling/index.rst b/docs/source/reference/g2p_modeling/index.rst new file mode 100644 index 00000000..8d069ba0 --- /dev/null +++ b/docs/source/reference/g2p_modeling/index.rst @@ -0,0 +1,19 @@ + +.. _g2p_modeling_api: + +Grapheme-to-Phoneme (G2P) models +================================ + +G2P models are used to generate pronunciations from orthographic spellings. The G2P models currently supported use Pynini weighted finite state transducers (wFST) based on a training lexicon. + +.. currentmodule:: montreal_forced_aligner.models + +.. autosummary:: + :toctree: generated/ + + G2PModel + +.. 
toctree:: + + training + helper diff --git a/docs/source/reference/g2p_modeling/training.rst b/docs/source/reference/g2p_modeling/training.rst new file mode 100644 index 00000000..70e67566 --- /dev/null +++ b/docs/source/reference/g2p_modeling/training.rst @@ -0,0 +1,10 @@ +Training G2P models +=================== + +.. currentmodule:: montreal_forced_aligner.g2p.trainer + +.. autosummary:: + :toctree: generated/ + + PyniniTrainer -- Trainer for Pynini G2P model + PyniniValidator -- Validator for Pynini G2P model diff --git a/docs/source/reference/helper/abc.rst b/docs/source/reference/helper/abc.rst new file mode 100644 index 00000000..eb5a2990 --- /dev/null +++ b/docs/source/reference/helper/abc.rst @@ -0,0 +1,20 @@ +.. automodule:: montreal_forced_aligner.abc + + .. autosummary:: + :toctree: generated/ + + MfaModel -- Base model type for MFA + MfaWorker -- Base worker class for MFA + TopLevelMfaWorker -- MFA workers that have acoustic models + TrainerMixin -- Trainer type interface + TemporaryDirectoryMixin -- Trainer type interface + AdapterMixin -- Trainer type interface + FileExporterMixin -- File exporter type interface + ModelExporterMixin -- Model exporter type interface + +.. automodule:: montreal_forced_aligner.models + + .. autosummary:: + :toctree: generated/ + + Archive diff --git a/docs/source/reference/helper/command_line.rst b/docs/source/reference/helper/command_line.rst new file mode 100644 index 00000000..1f9e6032 --- /dev/null +++ b/docs/source/reference/helper/command_line.rst @@ -0,0 +1,29 @@ +Command line functions +====================== + +.. currentmodule:: montreal_forced_aligner.command_line + +.. 
autosummary:: + :toctree: generated/ + + main + create_parser + validate_model_arg + run_transcribe_corpus + run_validate_corpus + run_train_lm + run_train_g2p + run_align_corpus + run_train_dictionary + run_anchor + run_adapt_model + run_train_acoustic_model + run_train_ivector_extractor + run_g2p + run_create_segments + run_classify_speakers + run_model + list_model + save_model + inspect_model + download_model diff --git a/docs/source/reference/helper/config.rst b/docs/source/reference/helper/config.rst new file mode 100644 index 00000000..a2870265 --- /dev/null +++ b/docs/source/reference/helper/config.rst @@ -0,0 +1,11 @@ +.. automodule:: montreal_forced_aligner.config + + .. autosummary:: + :toctree: generated/ + + generate_config_path + generate_command_history_path + load_command_history + update_command_history + update_global_config + load_global_config diff --git a/docs/source/reference/data.rst b/docs/source/reference/helper/data.rst similarity index 100% rename from docs/source/reference/data.rst rename to docs/source/reference/helper/data.rst diff --git a/docs/source/reference/exceptions.rst b/docs/source/reference/helper/exceptions.rst similarity index 100% rename from docs/source/reference/exceptions.rst rename to docs/source/reference/helper/exceptions.rst diff --git a/docs/source/reference/helper.rst b/docs/source/reference/helper/helper.rst similarity index 81% rename from docs/source/reference/helper.rst rename to docs/source/reference/helper/helper.rst index dd9269c6..ffb26f96 100644 --- a/docs/source/reference/helper.rst +++ b/docs/source/reference/helper/helper.rst @@ -13,3 +13,6 @@ score edit_distance output_mapping + compare_labels + overlap_scoring + align_phones diff --git a/docs/source/reference/helper_index.rst b/docs/source/reference/helper/index.rst similarity index 82% rename from docs/source/reference/helper_index.rst rename to docs/source/reference/helper/index.rst index 52e42f30..55fe2abc 100644 --- 
a/docs/source/reference/helper_index.rst +++ b/docs/source/reference/helper/index.rst @@ -1,3 +1,6 @@ + +.. _helper_api: + Helper ====== @@ -9,6 +12,5 @@ Helper data exceptions helper - multiprocessing/index textgrid utils diff --git a/docs/source/reference/textgrid.rst b/docs/source/reference/helper/textgrid.rst similarity index 90% rename from docs/source/reference/textgrid.rst rename to docs/source/reference/helper/textgrid.rst index a924c0e6..90524353 100644 --- a/docs/source/reference/textgrid.rst +++ b/docs/source/reference/helper/textgrid.rst @@ -11,4 +11,3 @@ export_textgrid ctm_to_textgrid output_textgrid_writing_errors - ctms_to_textgrids_non_mp diff --git a/docs/source/reference/helper/utils.rst b/docs/source/reference/helper/utils.rst new file mode 100644 index 00000000..e34f5946 --- /dev/null +++ b/docs/source/reference/helper/utils.rst @@ -0,0 +1,15 @@ +.. automodule:: montreal_forced_aligner.utils + + .. autosummary:: + :toctree: generated/ + + Counter + Stopped + ProcessWorker + run_mp + run_non_mp + thirdparty_binary + log_kaldi_errors + guess_model_type + parse_logs + CustomFormatter diff --git a/docs/source/reference/index.rst b/docs/source/reference/index.rst index 2364d7e7..71f7d8f5 100644 --- a/docs/source/reference/index.rst +++ b/docs/source/reference/index.rst @@ -5,14 +5,22 @@ MFA API .. warning:: - While the MFA API is fairly stable, I do tend to do refactors on fairly regular basis. As 2.0 gets more stable, these are likely to get smaller and smaller, and I will try to keep the API docs as up-to-date as possible, so if something breaks in any scripts depending on MFA, please check back here. + While the MFA command-line interface is fairly stable, I do tend to do refactors of the internal code on fairly regular basis. As 2.0 gets more stable, these are likely to get smaller and smaller, and I will try to keep the API docs as up-to-date as possible, so if something breaks in any scripts depending on MFA, please check back here. 
-API definition --------------- +Current structure +----------------- + +Prior to 2.0.0b8, MFA classes were fairly monolithic. There was a ``Corpus`` class that did everything related to corpus loading and processing text and sound files. However, the default acoustic model with sound files for alignment does not necessarily lend itself to language modeling for instance, and so there were several flags for text-only behavior that didn't feel satisfying. + +A bigger concern was that as more configuration options were added for processing pronunciation dictionaries, they would have to be duplicated in the existing workflow configuration objects (AlignConfig, TranscribeConfig, etc), or a new DictionaryConfig object that gets passed to all workflow classes (PretrainedAligner, Transcriber, etc) and data processing classes (Corpus, Dictionary). + +The current design mixes in functionality as necessary with abstract classes. So there is a :class:`~montreal_forced_aligner.dictionary.mixins.DictionaryMixin` class that covers the functionality around what counts as a word, how to parse text through stripping punctuation, using compound and clitic markers in looking up words, etc. There are several :class:`~montreal_forced_aligner.corpus.base.CorpusMixin` classes that have similar data structure and attributes, but different load functionality for corpora with sound files (:class:`~montreal_forced_aligner.corpus.acoustic_corpus.AcousticCorpusMixin` and :class:`~montreal_forced_aligner.corpus.acoustic_corpus.AcousticCorpusPronunciationMixin`, depending on whether a pronunciation dictionary is needed for processing) versus text-only corpora (:class:`~montreal_forced_aligner.corpus.text_corpus.TextCorpusMixin` and :class:`~montreal_forced_aligner.corpus.text_corpus.DictionaryTextCorpusMixin`). + +This should (hopefully) make it easier to extend MFA for your own purposes if you so choose, and will certainly make it easier for me to implement new functionality going forward. .. 
toctree:: - :maxdepth: 2 + :hidden: - base_index - workers_index - helper_index + core_index + top_level_index + helper/index diff --git a/docs/source/reference/ivector/helper.rst b/docs/source/reference/ivector/helper.rst new file mode 100644 index 00000000..44b5762e --- /dev/null +++ b/docs/source/reference/ivector/helper.rst @@ -0,0 +1,29 @@ +Training functionality +====================== + +Mixins +------ + +.. currentmodule:: montreal_forced_aligner.ivector.trainer + +.. autosummary:: + :toctree: generated/ + + IvectorModelTrainingMixin + +Helper +------ + +.. currentmodule:: montreal_forced_aligner.ivector.trainer + +.. autosummary:: + :toctree: generated/ + + gmm_gselect_func + GmmGselectArguments + gauss_to_post_func + GaussToPostArguments + acc_global_stats_func + AccGlobalStatsArguments + acc_ivector_stats_func + AccIvectorStatsArguments diff --git a/docs/source/reference/ivector/index.rst b/docs/source/reference/ivector/index.rst new file mode 100644 index 00000000..3b2ae478 --- /dev/null +++ b/docs/source/reference/ivector/index.rst @@ -0,0 +1,21 @@ + +.. _ivector_api: + +Ivector extraction +================== + +.. warning:: + + This feature is not fully implemented, and is still under construction. + +.. currentmodule:: montreal_forced_aligner.models + +.. autosummary:: + :toctree: generated/ + + IvectorExtractorModel + +.. toctree:: + + training + helper diff --git a/docs/source/reference/ivector/training.rst b/docs/source/reference/ivector/training.rst new file mode 100644 index 00000000..bf3a7a16 --- /dev/null +++ b/docs/source/reference/ivector/training.rst @@ -0,0 +1,18 @@ + +.. _training_ivector_api: + +Training ivector extractors +=========================== + +.. warning:: + + This feature is not fully implemented, and is still under construction. + +.. currentmodule:: montreal_forced_aligner.ivector.trainer + +.. 
autosummary:: + :toctree: generated/ + + IvectorTrainer -- Training ivector extractor models + DubmTrainer -- Training block for DUBM + TrainableIvectorExtractor -- Top level worker for running Ivector training pipelines diff --git a/docs/source/reference/language_modeling/helper.rst b/docs/source/reference/language_modeling/helper.rst new file mode 100644 index 00000000..d707a63c --- /dev/null +++ b/docs/source/reference/language_modeling/helper.rst @@ -0,0 +1,9 @@ +Helper functionality +==================== + +.. currentmodule:: montreal_forced_aligner.language_modeling.trainer + +.. autosummary:: + :toctree: generated/ + + LmTrainerMixin -- Mixin for language model training diff --git a/docs/source/reference/language_modeling/index.rst b/docs/source/reference/language_modeling/index.rst new file mode 100644 index 00000000..4b3abc2d --- /dev/null +++ b/docs/source/reference/language_modeling/index.rst @@ -0,0 +1,19 @@ + +.. _language_modeling_api: + +Language models +=============== + +Language models allow for transcription via Speech-to-Text when used alongside acoustic models and pronunciation dictionaries. + +.. currentmodule:: montreal_forced_aligner.models + +.. autosummary:: + :toctree: generated/ + + LanguageModel + +.. toctree:: + + training + helper diff --git a/docs/source/reference/language_modeling/training.rst b/docs/source/reference/language_modeling/training.rst new file mode 100644 index 00000000..813155f3 --- /dev/null +++ b/docs/source/reference/language_modeling/training.rst @@ -0,0 +1,14 @@ + +.. _language_model_training_api: + +Training language models +======================== + +.. currentmodule:: montreal_forced_aligner.language_modeling.trainer + +.. 
autosummary:: + :toctree: generated/ + + LmCorpusTrainer -- Trainer for language model on text corpora + LmDictionaryCorpusTrainer -- Trainer for language model on text corpora + LmArpaTrainer -- Trainer for MFA language model on arpa format language model diff --git a/docs/source/reference/lm.rst b/docs/source/reference/lm.rst deleted file mode 100644 index 49861751..00000000 --- a/docs/source/reference/lm.rst +++ /dev/null @@ -1,6 +0,0 @@ -.. automodule:: montreal_forced_aligner.lm - - .. autosummary:: - :toctree: generated/ - - LmTrainer -- Trainer for language model diff --git a/docs/source/reference/models.rst b/docs/source/reference/models.rst deleted file mode 100644 index 9aa5e22c..00000000 --- a/docs/source/reference/models.rst +++ /dev/null @@ -1,12 +0,0 @@ - -.. automodule:: montreal_forced_aligner.models - - .. autosummary:: - :toctree: generated/ - - Archive - LanguageModel - AcousticModel - IvectorExtractorModel - DictionaryModel - G2PModel diff --git a/docs/source/reference/multiprocessing/alignment.rst b/docs/source/reference/multiprocessing/alignment.rst deleted file mode 100644 index 8be63d82..00000000 --- a/docs/source/reference/multiprocessing/alignment.rst +++ /dev/null @@ -1,107 +0,0 @@ -Alignment -========= - -Basic ------ - -.. currentmodule:: montreal_forced_aligner.multiprocessing.alignment - -.. autosummary:: - :toctree: generated/ - - acc_stats - acc_stats_func - align - align_func - mono_align_equal - mono_align_equal_func - tree_stats - tree_stats_func - compile_train_graphs - compile_train_graphs_func - convert_alignments - convert_alignments_func - -LDA training ------------- - -.. currentmodule:: montreal_forced_aligner.multiprocessing.alignment - -.. autosummary:: - :toctree: generated/ - - calc_lda_mllt - calc_lda_mllt_func - lda_acc_stats - lda_acc_stats_func - -Speaker adapted models ----------------------- - -.. currentmodule:: montreal_forced_aligner.multiprocessing.alignment - -.. 
autosummary:: - :toctree: generated/ - - calc_fmllr - calc_fmllr_func - create_align_model - acc_stats_two_feats_func - -Acoustic model adaptation -------------------------- - -.. currentmodule:: montreal_forced_aligner.multiprocessing.alignment - -.. autosummary:: - :toctree: generated/ - - train_map - map_acc_stats_func - - -TextGrid Export ---------------- - -.. currentmodule:: montreal_forced_aligner.multiprocessing.alignment - -.. autosummary:: - :toctree: generated/ - - ctms_to_textgrids_mp - convert_ali_to_textgrids - ali_to_ctm_func - PhoneCtmProcessWorker - CleanupWordCtmProcessWorker - NoCleanupWordCtmProcessWorker - CombineProcessWorker - ExportPreparationProcessWorker - ExportTextGridProcessWorker - -Pronunciation probabilities ---------------------------- - -.. currentmodule:: montreal_forced_aligner.multiprocessing.pronunciations - -.. autosummary:: - :toctree: generated/ - - generate_pronunciations - generate_pronunciations_func - -Validation ----------- - -.. currentmodule:: montreal_forced_aligner.multiprocessing.alignment - -.. autosummary:: - :toctree: generated/ - - compile_information - compile_information_func - compute_alignment_improvement - compute_alignment_improvement_func - compare_alignments - parse_iteration_alignments - compile_utterance_train_graphs_func - test_utterances_func diff --git a/docs/source/reference/multiprocessing/corpus.rst b/docs/source/reference/multiprocessing/corpus.rst deleted file mode 100644 index 568567b4..00000000 --- a/docs/source/reference/multiprocessing/corpus.rst +++ /dev/null @@ -1,23 +0,0 @@ -Corpora -======= - -.. automodule:: montreal_forced_aligner.multiprocessing.corpus - - .. autosummary:: - :toctree: generated/ - - CorpusProcessWorker - -Features --------- - -.. currentmodule:: montreal_forced_aligner.multiprocessing.features - -.. 
autosummary:: - :toctree: generated/ - - mfcc - mfcc_func - calc_cmvn - compute_vad - compute_vad_func diff --git a/docs/source/reference/multiprocessing/helper.rst b/docs/source/reference/multiprocessing/helper.rst deleted file mode 100644 index 24ba9ce0..00000000 --- a/docs/source/reference/multiprocessing/helper.rst +++ /dev/null @@ -1,65 +0,0 @@ -Helper -====== - -Functions ---------- - -.. currentmodule:: montreal_forced_aligner.multiprocessing.helper - -.. autosummary:: - :toctree: generated/ - - Counter - Stopped - ProcessWorker - run_mp - run_non_mp - -Classes -------- - -.. currentmodule:: montreal_forced_aligner.multiprocessing.classes - -.. autosummary:: - :toctree: generated/ - - Job - AlignArguments - VadArguments - SegmentVadArguments - CreateHclgArguments - AccGlobalStatsArguments - AccStatsArguments - AccIvectorStatsArguments - AccStatsTwoFeatsArguments - AliToCtmArguments - MfccArguments - ScoreArguments - DecodeArguments - PhoneCtmArguments - CombineCtmArguments - CleanupWordCtmArguments - NoCleanupWordCtmArguments - LmRescoreArguments - AlignmentImprovementArguments - ConvertAlignmentsArguments - CalcFmllrArguments - CalcLdaMlltArguments - GmmGselectArguments - FinalFmllrArguments - LatGenFmllrArguments - FmllrRescoreArguments - TreeStatsArguments - LdaAccStatsArguments - MapAccStatsArguments - GaussToPostArguments - InitialFmllrArguments - ExtractIvectorsArguments - ExportTextGridArguments - CompileTrainGraphsArguments - CompileInformationArguments - CompileUtteranceTrainGraphsArguments - MonoAlignEqualArguments - TestUtterancesArguments - CarpaLmRescoreArguments - GeneratePronunciationsArguments diff --git a/docs/source/reference/multiprocessing/index.rst b/docs/source/reference/multiprocessing/index.rst deleted file mode 100644 index f24e7fd8..00000000 --- a/docs/source/reference/multiprocessing/index.rst +++ /dev/null @@ -1,10 +0,0 @@ -Multiprocessing helper functions -================================ - -.. 
toctree:: - - corpus - alignment - ivector - transcription - helper diff --git a/docs/source/reference/multiprocessing/ivector.rst b/docs/source/reference/multiprocessing/ivector.rst deleted file mode 100644 index c048750a..00000000 --- a/docs/source/reference/multiprocessing/ivector.rst +++ /dev/null @@ -1,35 +0,0 @@ -Ivector -======= - -.. automodule:: montreal_forced_aligner.multiprocessing.ivector - - .. autosummary:: - :toctree: generated/ - - gmm_gselect - gmm_gselect_func - gauss_to_post - gauss_to_post_func - acc_global_stats - acc_global_stats_func - acc_ivector_stats - acc_ivector_stats_func - extract_ivectors - extract_ivectors_func - segment_vad - segment_vad_func - get_initial_segmentation - merge_segments - -File segmentation ------------------ - -.. currentmodule:: montreal_forced_aligner.multiprocessing.ivector - -.. autosummary:: - :toctree: generated/ - - segment_vad - segment_vad_func - get_initial_segmentation - merge_segments diff --git a/docs/source/reference/multiprocessing/transcription.rst b/docs/source/reference/multiprocessing/transcription.rst deleted file mode 100644 index 5a9275ba..00000000 --- a/docs/source/reference/multiprocessing/transcription.rst +++ /dev/null @@ -1,47 +0,0 @@ -Transcription -============= - -Decoding graph --------------- - -.. currentmodule:: montreal_forced_aligner.multiprocessing.transcription - -.. autosummary:: - :toctree: generated/ - - create_hclgs - create_hclg_func - compose_hclg - compose_clg - compose_lg - compose_g - compose_g_carpa - -Speaker-independent transcription ---------------------------------- - -.. currentmodule:: montreal_forced_aligner.multiprocessing.transcription - -.. autosummary:: - :toctree: generated/ - - transcribe - decode_func - lm_rescore_func - carpa_lm_rescore_func - score_transcriptions - score_func - -Speaker-adapted transcription ------------------------------ - -.. currentmodule:: montreal_forced_aligner.multiprocessing.transcription - -.. 
autosummary:: - :toctree: generated/ - - transcribe_fmllr - initial_fmllr_func - lat_gen_fmllr_func - fmllr_rescore_func - final_fmllr_est_func diff --git a/docs/source/reference/segmentation/helper.rst b/docs/source/reference/segmentation/helper.rst new file mode 100644 index 00000000..7e6ea8f2 --- /dev/null +++ b/docs/source/reference/segmentation/helper.rst @@ -0,0 +1,13 @@ + +Helper functions +================ + +.. currentmodule:: montreal_forced_aligner.segmenter + +.. autosummary:: + :toctree: generated/ + + segment_vad_func + SegmentVadArguments + get_initial_segmentation + merge_segments diff --git a/docs/source/reference/segmentation/index.rst b/docs/source/reference/segmentation/index.rst new file mode 100644 index 00000000..1600cfb9 --- /dev/null +++ b/docs/source/reference/segmentation/index.rst @@ -0,0 +1,16 @@ + +.. _segmentation_api: + +Segmentation +============ + +Segmentation aims to break long audio files into chunks of speech. + +.. note:: + + The current implementation of segmentation uses only Voice Activity Detection (VAD) features. There's been some work towards getting a full speaker diarization set up going with :ref:`training_ivector_api` but that's largely planned for 2.1. + +.. toctree:: + + main + helper diff --git a/docs/source/reference/segmentation/main.rst b/docs/source/reference/segmentation/main.rst new file mode 100644 index 00000000..e502e93c --- /dev/null +++ b/docs/source/reference/segmentation/main.rst @@ -0,0 +1,10 @@ + +Segmenter +========= + +.. currentmodule:: montreal_forced_aligner.segmenter + +.. autosummary:: + :toctree: generated/ + + Segmenter diff --git a/docs/source/reference/segmenter.rst b/docs/source/reference/segmenter.rst deleted file mode 100644 index be2a29a6..00000000 --- a/docs/source/reference/segmenter.rst +++ /dev/null @@ -1,6 +0,0 @@ -.. automodule:: montreal_forced_aligner.segmenter - - .. 
autosummary:: - :toctree: generated/ - - Segmenter diff --git a/docs/source/reference/speaker_classifier.rst b/docs/source/reference/speaker_classifier.rst deleted file mode 100644 index 34a81d0e..00000000 --- a/docs/source/reference/speaker_classifier.rst +++ /dev/null @@ -1,6 +0,0 @@ -.. automodule:: montreal_forced_aligner.speaker_classifier - - .. autosummary:: - :toctree: generated/ - - SpeakerClassifier diff --git a/docs/source/reference/top_level_index.rst b/docs/source/reference/top_level_index.rst new file mode 100644 index 00000000..468efb9f --- /dev/null +++ b/docs/source/reference/top_level_index.rst @@ -0,0 +1,10 @@ +Workflows +========= + +.. toctree:: + + alignment/index + validation/index + g2p/index + transcription/index + segmentation/index diff --git a/docs/source/reference/trainers.rst b/docs/source/reference/trainers.rst deleted file mode 100644 index bce0e269..00000000 --- a/docs/source/reference/trainers.rst +++ /dev/null @@ -1,12 +0,0 @@ - -.. automodule:: montreal_forced_aligner.trainers - - .. autosummary:: - :toctree: generated/ - - BaseTrainer -- Base trainer - MonophoneTrainer -- Monophone trainer - TriphoneTrainer -- Triphone trainer - LdaTrainer -- LDA trainer - SatTrainer -- Speaker adapted trainer - IvectorExtractorTrainer -- Trainer for IvectorExtractor diff --git a/docs/source/reference/transcriber.rst b/docs/source/reference/transcriber.rst deleted file mode 100644 index 5b6d32ba..00000000 --- a/docs/source/reference/transcriber.rst +++ /dev/null @@ -1,6 +0,0 @@ -.. automodule:: montreal_forced_aligner.transcriber - - .. autosummary:: - :toctree: generated/ - - Transcriber diff --git a/docs/source/reference/transcription/helper.rst b/docs/source/reference/transcription/helper.rst new file mode 100644 index 00000000..763dd271 --- /dev/null +++ b/docs/source/reference/transcription/helper.rst @@ -0,0 +1,63 @@ +Helper functions +================ + +Mixins +------ + +.. 
currentmodule:: montreal_forced_aligner.transcription.transcriber + +.. autosummary:: + :toctree: generated/ + + TranscriberMixin + +Decoding graph +-------------- + +.. currentmodule:: montreal_forced_aligner.transcription.multiprocessing + +.. autosummary:: + :toctree: generated/ + + create_hclg_func + CreateHclgArguments + compose_hclg + compose_clg + compose_lg + compose_g + compose_g_carpa + + +Speaker-independent transcription +--------------------------------- + +.. currentmodule:: montreal_forced_aligner.transcription.multiprocessing + +.. autosummary:: + :toctree: generated/ + + decode_func + DecodeArguments + lm_rescore_func + LmRescoreArguments + carpa_lm_rescore_func + CarpaLmRescoreArguments + score_func + ScoreArguments + +Speaker-adapted transcription +----------------------------- + +.. currentmodule:: montreal_forced_aligner.transcription.multiprocessing + +.. autosummary:: + :toctree: generated/ + + initial_fmllr_func + InitialFmllrArguments + lat_gen_fmllr_func + LatGenFmllrArguments + fmllr_rescore_func + FmllrRescoreArguments + final_fmllr_est_func + FinalFmllrArguments diff --git a/docs/source/reference/transcription/index.rst b/docs/source/reference/transcription/index.rst new file mode 100644 index 00000000..237d5426 --- /dev/null +++ b/docs/source/reference/transcription/index.rst @@ -0,0 +1,12 @@ + +.. _transcription_api: + +Transcription +============= + +MFA can use trained acoustic models (see :ref:`acoustic_model_training_api`), trained language models (see :ref:`language_model_training_api`), and pronunciation dictionaries (see :ref:`generating_dictionaries_api`) in order to generate transcripts for audio files. + +.. toctree:: + + main + helper diff --git a/docs/source/reference/transcription/main.rst b/docs/source/reference/transcription/main.rst new file mode 100644 index 00000000..bfc568e0 --- /dev/null +++ b/docs/source/reference/transcription/main.rst @@ -0,0 +1,9 @@ +Transcriber +=========== + +.. 
currentmodule:: montreal_forced_aligner.transcription + +.. autosummary:: + :toctree: generated/ + + Transcriber diff --git a/docs/source/reference/utils.rst b/docs/source/reference/utils.rst deleted file mode 100644 index 6b100776..00000000 --- a/docs/source/reference/utils.rst +++ /dev/null @@ -1,24 +0,0 @@ -.. automodule:: montreal_forced_aligner.utils - - .. autosummary:: - :toctree: generated/ - - thirdparty_binary - get_available_dictionaries - log_config - log_kaldi_errors - get_available_models - get_available_language_models - get_available_acoustic_models - get_available_g2p_models - get_pretrained_language_model_path - get_pretrained_g2p_path - get_pretrained_ivector_path - get_pretrained_path - get_pretrained_acoustic_path - get_dictionary_path - get_available_ivector_extractors - guess_model_type - parse_logs - setup_logger - CustomFormatter diff --git a/docs/source/reference/validation/helper.rst b/docs/source/reference/validation/helper.rst new file mode 100644 index 00000000..9bbbc81f --- /dev/null +++ b/docs/source/reference/validation/helper.rst @@ -0,0 +1,26 @@ +Helper functions +================ + +Mixins +------ + +.. currentmodule:: montreal_forced_aligner.validator + +.. autosummary:: + :toctree: generated/ + + ValidationMixin + + +Helper +------ + +.. currentmodule:: montreal_forced_aligner.validator + +.. autosummary:: + :toctree: generated/ + + compile_utterance_train_graphs_func + test_utterances_func + CompileUtteranceTrainGraphsArguments + TestUtterancesArguments diff --git a/docs/source/reference/validation/index.rst b/docs/source/reference/validation/index.rst new file mode 100644 index 00000000..8ab91c43 --- /dev/null +++ b/docs/source/reference/validation/index.rst @@ -0,0 +1,12 @@ + +.. _validation_api: + +Validation +========== + +The validation utilities are used to evaluate a dataset for either training an acoustic model, or performing alignment. 
They will detect issues with sound files, transcription files, unalignable utterances, and can perform some simplistic evaluation of transcripts. + +.. toctree:: + + main + helper diff --git a/docs/source/reference/validation/main.rst b/docs/source/reference/validation/main.rst new file mode 100644 index 00000000..49c59c22 --- /dev/null +++ b/docs/source/reference/validation/main.rst @@ -0,0 +1,10 @@ +Validators +========== + +.. currentmodule:: montreal_forced_aligner.validator + +.. autosummary:: + :toctree: generated/ + + TrainingValidator + PretrainedValidator diff --git a/docs/source/reference/validator.rst b/docs/source/reference/validator.rst deleted file mode 100644 index c0e951cd..00000000 --- a/docs/source/reference/validator.rst +++ /dev/null @@ -1,6 +0,0 @@ -.. automodule:: montreal_forced_aligner.validator - - .. autosummary:: - :toctree: generated/ - - CorpusValidator diff --git a/docs/source/reference/workers_index.rst b/docs/source/reference/workers_index.rst deleted file mode 100644 index f0cf27a5..00000000 --- a/docs/source/reference/workers_index.rst +++ /dev/null @@ -1,13 +0,0 @@ -MFA workers -=========== - -.. toctree:: - - aligner - g2p - lm - segmenter - speaker_classifier - trainers - transcriber - validator diff --git a/docs/source/user_guide/commands.rst b/docs/source/user_guide/commands.rst index 1ec3c38e..a8df37d1 100644 --- a/docs/source/user_guide/commands.rst +++ b/docs/source/user_guide/commands.rst @@ -13,46 +13,46 @@ Preparation .. csv-table:: :header: "Command", "Description", "Link" - :widths: 10, 110, 40 + :widths: 50, 110, 40 - "validate", "Validate a corpus", :ref:`validating_data` + "``mfa validate``", "Validate a corpus", :ref:`validating_data` Forced Alignment ================ .. 
csv-table:: :header: "Command", "Description", "Link" - :widths: 10, 110, 40 + :widths: 50, 110, 40 - "align", "Perform forced alignment with a pretrained model", :ref:`pretrained_alignment` - "train", "Train an acoustic model and export resulting alignment", :ref:`train_acoustic_model` - "adapt", "Adapt a pretrained acoustic model on a new dataset", :ref:`adapt_acoustic_model` - "train_dictionary", "Estimate pronunciation probabilities from aligning a corpus", :ref:`training_dictionary` + "``mfa align``", "Perform forced alignment with a pretrained model", :ref:`pretrained_alignment` + "``mfa train``", "Train an acoustic model and export resulting alignment", :ref:`train_acoustic_model` + "``mfa adapt``", "Adapt a pretrained acoustic model on a new dataset", :ref:`adapt_acoustic_model` + "``mfa train_dictionary``", "Estimate pronunciation probabilities from aligning a corpus", :ref:`training_dictionary` Corpus creation =============== .. csv-table:: :header: "Command", "Description", "Link" - :widths: 10, 110, 40 + :widths: 50, 110, 40 - "create_segments", "Use voice activity detection to create segments", :ref:`create_segments` - "train_ivector", "Train an ivector extractor for speaker classification", :ref:`train_ivector` - "classify_speakers", "Use ivector extractor to classify files or cluster them", :ref:`classify_speakers` - "transcribe", "Generate transcriptions using an acoustic model, dictionary, and language model", :ref:`transcribing` - "train_lm", "Train a language model from a text corpus or from an existing language model", :ref:`training_lm` - "anchor", "Run the Anchor annotator utility (if installed) for editing and managing corpora", :ref:`anchor` + "``mfa create_segments``", "Use voice activity detection to create segments", :ref:`create_segments` + "``mfa train_ivector``", "Train an ivector extractor for speaker classification", :ref:`train_ivector` + "``mfa classify_speakers``", "Use ivector extractor to classify files or cluster them", 
:ref:`classify_speakers` + "``mfa transcribe``", "Generate transcriptions using an acoustic model, dictionary, and language model", :ref:`transcribing` + "``mfa train_lm``", "Train a language model from a text corpus or from an existing language model", :ref:`training_lm` + "``mfa anchor``", "Run the Anchor annotator utility (if installed) for editing and managing corpora", :ref:`anchor` Other utilities =============== .. csv-table:: :header: "Command", "Description", "Link" - :widths: 10, 110, 40 + :widths: 50, 110, 40 - "model", "Inspect/list/download/save models", :ref:`pretrained_models` - "configure", "Configure MFA to use customized defaults for command line arguments", :ref:`configuration` - "history", "List previous MFA commands run locally", + "``mfa model``", "Inspect/list/download/save models", :ref:`pretrained_models` + "``mfa configure``", "Configure MFA to use customized defaults for command line arguments", :ref:`configuration` + "``mfa history``", "List previous MFA commands run locally", Grapheme-to-phoneme @@ -60,7 +60,7 @@ Grapheme-to-phoneme .. csv-table:: :header: "Command", "Description", "Link" - :widths: 10, 110, 40 + :widths: 50, 110, 40 - "g2p", "Use a G2P model to generate a pronunciation dictionary", :ref:`g2p_dictionary_generating` - "train_g2p", "Train a G2P model from a pronunciation dictionary", :ref:`g2p_model_training` + "``mfa g2p``", "Use a G2P model to generate a pronunciation dictionary", :ref:`g2p_dictionary_generating` + "``mfa train_g2p``", "Train a G2P model from a pronunciation dictionary", :ref:`g2p_model_training` diff --git a/docs/source/user_guide/configuration/acoustic_model_adapt.rst b/docs/source/user_guide/configuration/acoustic_model_adapt.rst new file mode 100644 index 00000000..8655cae9 --- /dev/null +++ b/docs/source/user_guide/configuration/acoustic_model_adapt.rst @@ -0,0 +1,13 @@ + +.. 
_configuration_adapting: + +Acoustic model adaptation options +================================= + +For the Kaldi recipe that acoustic model adaptation is based on, see :kaldi_steps:`train_map`. + + +.. csv-table:: + :header: "Parameter", "Default value", "Notes" + + "mapping_tau", 20, "smoothing constant used in MAP estimation, corresponds to the number of 'fake counts' that we add for the old model. Larger tau corresponds to less aggressive re-estimation, and more smoothing. You might also want to try 10 or 15." diff --git a/docs/source/user_guide/configuration/align.rst b/docs/source/user_guide/configuration/acoustic_modeling.rst similarity index 52% rename from docs/source/user_guide/configuration/align.rst rename to docs/source/user_guide/configuration/acoustic_modeling.rst index cf23a5b8..c07ead87 100644 --- a/docs/source/user_guide/configuration/align.rst +++ b/docs/source/user_guide/configuration/acoustic_modeling.rst @@ -1,80 +1,24 @@ -.. _configuration_alignment: +.. _configuration_acoustic_modeling: -*********************** -Alignment Configuration -*********************** +******************************* +Acoustic model training options +******************************* -Global options -============== - -These options are used for aligning the full dataset (and as part of training). Increasing the values of them will -allow for more relaxed restrictions on alignment. Relaxing these restrictions can be particularly helpful for certain -kinds of files that are quite different from the training dataset (i.e., single word production data from experiments, -or longer stretches of audio). - - -..
csv-table:: - :header: "Parameter", "Default value", "Notes" - :escape: ' - - "beam", 10, "Initial beam width to use for alignment" - "retry_beam", 40, "Beam width to use if initial alignment fails" - "transition_scale", 1.0, "Multiplier to scale transition costs" - "acoustic_scale", 0.1, "Multiplier to scale acoustic costs" - "self_loop_scale", 0.1, "Multiplier to scale self loop costs" - "boost_silence", 1.0, "1.0 is the value that does not affect probabilities" - "punctuation", "、。।,@<>'"'(),.:;¿?¡!\\&%#*~【】,…‥「」『』〝〟″⟨⟩♪・‹›«»~′$+=", "Characters to treat as punctuation and strip from around words" - "clitic_markers", "'''’", "Characters to treat as clitic markers, will be collapsed to the first character in the string" - "compound_markers", "\-", "Characters to treat as marker in compound words (i.e., doesn't need to be preserved like for clitics)" - "multilingual_ipa", False, "Flag for enabling multilingual IPA mode, see :ref:`multilingual_ipa` for more details" - "strip_diacritics", "/iː/ /iˑ/ /ĭ/ /i̯/ /t͡s/ /t‿s/ /t͜s/ /n̩/", "IPA diacritics to strip in multilingual IPA mode (phone symbols for proper display, when specifying them just have the diacritic)" - "digraphs", "[dt][szʒʃʐʑʂɕç], [aoɔe][ʊɪ]", "Digraphs to split up in multilingual IPA mode" - - -.. _feature_config: - -Feature Configuration -===================== - -This section is only relevant for training, as the trained model will contain extractors and feature specification for -what it requires. - -.. csv-table:: - :header: "Parameter", "Default value", "Notes" +.. 
note:: - "type", "mfcc", "Currently only MFCCs are supported" - "use_energy", "False", "Use energy in place of first MFCC" - "frame_shift", 10, "In milliseconds, determines time resolution" - "snip_edges", True, "Should provide better time resolution in alignment" - "pitch", False, "Currently not implemented" - "low_frequency", 20, "Frequency cut off for feature generation" - "high_frequency", 7800, "Frequency cut off for feature generation" - "sample_frequency", 16000, "Sample rate to up- or down-sample to" - "allow_downsample", True, "Flag for allowing down-sampling" - "allow_upsample", True, "Flag for allowing up-sampling" - "splice_left_context", 3, "Frame width for generating LDA transforms" - "splice_right_context", 3, "Frame width for generating LDA transforms" - "use_mp", True, "Flag for whether to use multiprocessing feature generation" - -.. _training_config: - -Training configuration -====================== + See :ref:`configuration_global` for options relating to the alignment steps Global alignment options can be overwritten for each trainer (i.e., different beam settings at different stages of training). .. note:: - Subsets are created by sorting the utterances by length, taking a larger subset (10 times the specified subset amount) - and then randomly sampling the specified subset amount from this larger subset. Utterances with transcriptions that - are only one word long are ignored. + Subsets are created by sorting the utterances by length, taking a larger subset (10 times the specified subset amount) and then randomly sampling the specified subset amount from this larger subset. Utterances with transcriptions that are only one word long are ignored. Monophone Configuration ----------------------- -For the Kaldi recipe that monophone training is based on, see -https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/train_mono.sh +For the Kaldi recipe that monophone training is based on, see :kaldi_steps:`train_mono`. .. 
csv-table:: @@ -91,11 +35,10 @@ quarter of training will perform realignment every iteration, the second quarter and the final two quarters will perform realignment every third iteration. -Triphone Configuration ----------------------- +Triphone training options +------------------------- -For the Kaldi recipe that triphone training is based on, see -https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/train_deltas.sh +For the Kaldi recipe that triphone training is based on, see :kaldi_steps:`train_deltas`. .. csv-table:: :header: "Parameter", "Default value", "Notes" @@ -108,11 +51,10 @@ https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/train_deltas.sh "cluster_threshold", -1, "Threshold for clustering leaves in decision tree" -LDA Configuration ------------------ +LDA training options +-------------------- -For the Kaldi recipe that LDA training is based on, see -https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/train_lda_mllt.sh +For the Kaldi recipe that LDA training is based on, see :kaldi_steps:`train_lda_mllt`. .. csv-table:: :header: "Parameter", "Default value", "Notes" @@ -130,11 +72,10 @@ https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/train_lda_mllt.s LDA estimation will be performed every other iteration for the first quarter of iterations, and then one final LDA estimation will be performed halfway through the training iterations. -Speaker-adapted training (SAT) configuration --------------------------------------------- +Speaker-adapted training (SAT) options +-------------------------------------- -For the Kaldi recipe that SAT training is based on, see -https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/train_sat.sh +For the Kaldi recipe that SAT training is based on, see :kaldi_steps:`train_sat`. .. csv-table:: :header: "Parameter", "Default value", "Notes" @@ -158,6 +99,8 @@ will be performed halfway through the training iterations.
Default training config file ---------------------------- +The below configuration file shows the equivalent of the current 2.0 training regime, mostly as an example of what configuration options are available and how they progress through the overall training. + .. code-block:: yaml beam: 10 @@ -198,27 +141,24 @@ Default training config file max_gaussians: 15000 power: 0.2 silence_weight: 0.0 - fmllr_update_type: "diag" + fmllr_update_type: "full" subset: 10000 - features: - lda: true - sat: num_leaves: 4200 max_gaussians: 40000 power: 0.2 silence_weight: 0.0 - fmllr_update_type: "diag" + fmllr_update_type: "full" subset: 30000 - features: - lda: true - fmllr: true .. _1.0_training_config: Training configuration for 1.0 ------------------------------ +The below configuration matches the training procedure used in models trained in version 1.0. Note the lack of an LDA block, and only one SAT training block, as well as the lack of subsets in initial training blocks to speed up overall training. + .. code-block:: yaml beam: 10 @@ -250,14 +190,3 @@ Training configuration for 1.0 silence_weight: 0.0 cluster_threshold: 100 fmllr_update_type: "full" - - -.. _align_config: - -Align configuration -=================== - -.. code-block:: yaml - - beam: 10 - retry_beam: 40 diff --git a/docs/source/user_guide/configuration/dictionary.rst b/docs/source/user_guide/configuration/dictionary.rst deleted file mode 100644 index 25f7d836..00000000 --- a/docs/source/user_guide/configuration/dictionary.rst +++ /dev/null @@ -1,32 +0,0 @@ - -.. _configuration_dictionary: - -************************ -Dictionary Configuration -************************ - -Text normalization and parsing of words from text can be configured in yaml configuration files. Punctuation is stripped from all words, so if a character is part of a language's orthography, modifying the :code:`punctuation` parameter to exclude that character would keep that character in the words. 
See more examples of how these :code:`punctuation`, :code:`clitic_markers`, and :code:`compound_markers` are used in :ref:`text_normalization`. - -The :code:`multilingual_ipa`, :code:`strip_diacritics`, and :code:`digraphs` are all used as part of :ref:`multilingual_ipa`. - -.. csv-table:: - :header: "Parameter", "Default value", "Notes" - :escape: ' - - "oov_word", "", "Internal word symbol to use for out of vocabulary items" - "oov_phone", "spn", "Internal phone symbol to use for out of vocabulary items" - "silence_word", "!sil", "Internal word symbol to use initial silence" - "nonoptional_silence_phone", "sil", "Internal phone symbol to use initial silence" - "optional_silence_phone", "sp", "Internal phone symbol to use optional silence in the middle of utterances" - "position_dependent_phones", "True", "Flag for whether phones should mark their position in the word as part of the phone symbol internally" - "num_silence_states", "5", "Number of states to use for silence phones" - "num_non_silence_states", "3", "Number of states to use for non-silence phones" - "shared_silence_phones", "True", "Flag for whether to share silence phone models" - "silence_probability", "0.5", "Probability of inserting silence around and within utterances, setting to 0 removes silence modelling" - "punctuation", "、。।,@<>'"'(),.:;¿?¡!\\&%#*~【】,…‥「」『』〝〟″⟨⟩♪・‹›«»~′$+=", "Characters to treat as punctuation and strip from around words" - "clitic_markers", "'''’", "Characters to treat as clitic markers, will be collapsed to the first character in the string" - "compound_markers", "\-", "Characters to treat as marker in compound words (i.e., doesn't need to be preserved like for clitics)" - "multilingual_ipa", False, "Flag for enabling multilingual IPA mode, see :ref:`multilingual_ipa` for more details" - "strip_diacritics", "/iː/ /iˑ/ /ĭ/ /i̯/ /t͡s/ /t‿s/ /t͜s/ /n̩/", "IPA diacritics to strip in multilingual IPA mode (phone symbols for proper display, when specifying them just have the 
diacritic)" - "digraphs", "[dt][szʒʃʐʑʂɕç], [aoɔe][ʊɪ]", "Digraphs to split up in multilingual IPA mode" - "brackets", "('[', ']'), ('{', '}'), ('<', '>'), ('(', ')')", "Punctuation to keep as bracketing a whole word, i.e., a restart, disfluency, etc" diff --git a/docs/source/user_guide/configuration/g2p.rst b/docs/source/user_guide/configuration/g2p.rst index e1ff0af1..9f71f7c0 100644 --- a/docs/source/user_guide/configuration/g2p.rst +++ b/docs/source/user_guide/configuration/g2p.rst @@ -22,8 +22,8 @@ Global options .. _train_g2p_config: -Train G2P Configuration -======================= +G2P training options +==================== In addition to the parameters above, the following parameters are used as part of training a G2P model. @@ -41,11 +41,13 @@ In addition to the parameters above, the following parameters are used as part o "pruning_method", "relative_entropy", "Pruning method for pruning the ngram model" "model_size", 1000000, "Target number of ngrams for pruning" +Example G2P configuration files +=============================== .. _default_train_g2p_config: Default G2P training config file -================================ +-------------------------------- .. code-block:: yaml @@ -66,11 +68,10 @@ Default G2P training config file model_size: 1000000 - .. _default_g2p_config: -G2P generation configuration file -================================= +Default dictionary generation config file +----------------------------------------- .. code-block:: yaml diff --git a/docs/source/user_guide/configuration/global.rst b/docs/source/user_guide/configuration/global.rst new file mode 100644 index 00000000..de24432e --- /dev/null +++ b/docs/source/user_guide/configuration/global.rst @@ -0,0 +1,84 @@ + +.. _configuration_global: + +************** +Global Options +************** + +These options are used for aligning the full dataset (and as part of training). Increasing the values of them will +allow for more relaxed restrictions on alignment. 
Relaxing these restrictions can be particularly helpful for certain +kinds of files that are quite different from the training dataset (i.e., single word production data from experiments, +or longer stretches of audio). + + +.. csv-table:: + :header: "Parameter", "Default value", "Notes" + :escape: ' + + "beam", 10, "Initial beam width to use for alignment" + "retry_beam", 40, "Beam width to use if initial alignment fails" + "transition_scale", 1.0, "Multiplier to scale transition costs" + "acoustic_scale", 0.1, "Multiplier to scale acoustic costs" + "self_loop_scale", 0.1, "Multiplier to scale self loop costs" + "boost_silence", 1.0, "1.0 is the value that does not affect probabilities" + +.. _feature_config: + +Feature Configuration +===================== + +This section is only relevant for training, as the trained model will contain extractors and feature specification for what it requires. + +.. csv-table:: + :header: "Parameter", "Default value", "Notes" + + "feature_type", "mfcc", "Currently only MFCCs are supported" + "use_energy", "False", "Use energy in place of first MFCC" + "frame_shift", 10, "In milliseconds, determines time resolution" + "snip_edges", True, "Should provide better time resolution in alignment" + "pitch", False, "Currently not implemented" + "low_frequency", 20, "Frequency cut off for feature generation" + "high_frequency", 7800, "Frequency cut off for feature generation" + "sample_frequency", 16000, "Sample rate to up- or down-sample to" + "allow_downsample", True, "Flag for allowing down-sampling" + "allow_upsample", True, "Flag for allowing up-sampling" + "uses_cmvn", True, "Flag for whether to use CMVN" + "uses_deltas", True, "Flag for whether to use delta features" + "uses_splices", False, "Flag for whether to use splices and LDA transformations" + "splice_left_context", 3, "Frame width for generating LDA transforms" + "splice_right_context", 3, "Frame width for generating LDA transforms" + "uses_speaker_adaptation", False, "Flag
for whether to use speaker adaptation" + "fmllr_update_type", "full", "Type of fMLLR estimation" + "silence_weight", 0.0, "Weight of silence in calculating LDA or fMLLR" + + +.. _configuration_dictionary: + +Dictionary and text parsing options +=================================== + +This sections details configuration options related to how MFA normalizes text and performs dictionary look up. Punctuation is stripped from all words, so if a character is part of a language's orthography, modifying the :code:`punctuation` parameter to exclude that character would keep that character in the words. See more examples of how these :code:`punctuation`, :code:`clitic_markers`, and :code:`compound_markers` are used in :ref:`text_normalization`. + +The :code:`multilingual_ipa`, :code:`strip_diacritics`, and :code:`digraphs` are all used as part of :ref:`multilingual_ipa`. + +.. csv-table:: + :header: "Parameter", "Default value", "Notes" + :escape: ' + + "oov_word", "", "Internal word symbol to use for out of vocabulary items" + "oov_phone", "spn", "Internal phone symbol to use for out of vocabulary items" + "silence_word", "!sil", "Internal word symbol to use initial silence" + "nonoptional_silence_phone", "sil", "Internal phone symbol to use initial silence" + "optional_silence_phone", "sp", "Internal phone symbol to use optional silence in the middle of utterances" + "position_dependent_phones", "True", "Flag for whether phones should mark their position in the word as part of the phone symbol internally" + "num_silence_states", "5", "Number of states to use for silence phones" + "num_non_silence_states", "3", "Number of states to use for non-silence phones" + "shared_silence_phones", "True", "Flag for whether to share silence phone models" + "silence_probability", "0.5", "Probability of inserting silence around and within utterances, setting to 0 removes silence modelling" + "punctuation", "、。।,@<>'"'(),.:;¿?¡!\\&%#*~【】,…‥「」『』〝〟″⟨⟩♪・‹›«»~′$+=", "Characters to treat as 
punctuation and strip from around words" + "clitic_markers", "'''’", "Characters to treat as clitic markers, will be collapsed to the first character in the string" + "compound_markers", "\-", "Characters to treat as marker in compound words (i.e., doesn't need to be preserved like for clitics)" + "multilingual_ipa", False, "Flag for enabling multilingual IPA mode, see :ref:`multilingual_ipa` for more details" + "strip_diacritics", "/iː/ /iˑ/ /ĭ/ /i̯/ /t͡s/ /t‿s/ /t͜s/ /n̩/", "IPA diacritics to strip in multilingual IPA mode (phone symbols for proper display, when specifying them just have the diacritic)" + "digraphs", "[dt][szʒʃʐʑʂɕç], [aoɔe][ʊɪ]", "Digraphs to split up in multilingual IPA mode" + "brackets", "('[', ']'), ('{', '}'), ('<', '>'), ('(', ')')", "Punctuation to keep as bracketing a whole word, i.e., a restart, disfluency, etc" diff --git a/docs/source/user_guide/configuration/index.rst b/docs/source/user_guide/configuration/index.rst index 445d9bde..bba02121 100644 --- a/docs/source/user_guide/configuration/index.rst +++ b/docs/source/user_guide/configuration/index.rst @@ -5,25 +5,83 @@ Configuration ************* -Global configuration for MFA can be updated via the ``mfa configure`` subcommand. Once the command is called with a flag, it will set a default value for any future runs (though, you can overwrite most settings when you call other commands). +MFA root directory +================== + +MFA uses a temporary directory for commands that can be specified in running commands with ``--temp_directory`` (see below), and it also uses a directory to store global configuration settings and saved models. By default this root directory is ``~/Documents/MFA``, but if you would like to put this somewhere else, you can set the environment variable ``MFA_ROOT_DIR`` to use that. MFA will raise an error on load if it's unable to write to the root directory. 
-Command reference ------------------ +Global configuration +==================== + +Global configuration for MFA can be updated via the ``mfa configure`` subcommand. Once the command is called with a flag, it will set a default value for any future runs (though, you can overwrite most settings when you call other commands). -.. autoprogram:: montreal_forced_aligner.command_line.mfa:parser +.. autoprogram:: montreal_forced_aligner.command_line.mfa:create_parser() :prog: mfa :start_command: configure Configuring specific commands ============================= +MFA has the ability to customize various parameters that control aspects of data processing and workflows. These can be supplied via the command line like: + +.. code-block:: bash + + mfa align ... --beam 1000 + +The above command will set the beam width used in aligning to ``1000`` (and the retry beam width to 4000). This command is the equivalent of supplying a config file like the below via the ``--config_path``: + +.. code-block:: yaml + + beam: 1000 + +Supplying the above via: + +.. code-block:: bash + + mfa align ... --config_path config_above.yaml + +will also set the beam width to ``1000`` and retry beam width to ``4000`` as well. + +For simple settings, the command line argument approach can be good, but for more complex settings, the config yaml approach will allow you to specify things like aspects of training blocks or multilingual IPA flags: + +.. code-block:: yaml + + beam: 100 + retry_beam: 400 + + punctuation: ":,." + + multilingual_ipa: true + digraphs: + - "[dt][szʒʃʐʑʂɕç]" + - "[aoɔe][ʊɪ]" + + training: + - monophone: + num_iterations: 20 + max_gaussians: 500 + subset: 1000 + boost_silence: 1.25 + + - triphone: + num_iterations: 35 + num_leaves: 2000 + max_gaussians: 10000 + cluster_threshold: -1 + subset: 5000 + boost_silence: 1.25 + power: 0.25 + +You can then also override these options on the command line, i.e.
``--beam 10 --config_path config_above.yaml`` would reset the beam width to ``10``. Command line specified arguments always have higher priority over the parameters derived from a configuration yaml. + .. toctree:: - :maxdepth: 1 + :hidden: - dictionary.rst - align.rst - transcription.rst + global.rst + acoustic_modeling.rst + acoustic_model_adapt.rst + g2p.rst lm.rst + transcription.rst segment.rst ivector.rst - g2p.rst diff --git a/docs/source/user_guide/configuration/ivector.rst b/docs/source/user_guide/configuration/ivector.rst index 4bec1ed7..a31360c9 100644 --- a/docs/source/user_guide/configuration/ivector.rst +++ b/docs/source/user_guide/configuration/ivector.rst @@ -1,41 +1,59 @@ .. _configuration_ivector: -********************* -Ivector Configuration -********************* +********************************** +Ivector extractor training options +********************************** -For the Kaldi recipe that ivector extractor training is based on, see :kaldi_steps_sid:`train_diag_ubm` and :kaldi_steps_sid:`train_ivector_extractor`. +.. warning:: + + The current implementation of ivectors is a little spotty and there is a planned pass over the speaker diarization on the roadmap for 2.1. + +Diagonal UBM training +===================== + +For the Kaldi recipe that DUBM training is based on, see :kaldi_steps_sid:`train_diag_ubm`. + +.. 
csv-table:: + :header: "Parameter", "Default value", "Notes" + + "num_iterations", 4, "Number of iterations for training UBM" + "num_gselect", 30, "Number of Gaussian-selection indices to use while training" + "subsample", 5, "Subsample factor for feature frames" + "num_frames", 500000, "Number of frames to keep in memory for initialization" + "num_gaussians", 256, "Number of gaussians to use for DUBM training" + "num_iterations_init", 20, "Number of iteration to use when initializing UBM" + "initial_gaussian_proportion", 0.5, "Start with half the target number of Gaussians" + "min_gaussian_weight", 0.0001, "" + "remove_low_count_gaussians", True, "Flag for removing low count gaussians in the final round of training" + + +Ivector training +================ + +For the Kaldi recipe that ivector training is based on, see :kaldi_steps_sid:`train_ivector_extractor`. .. csv-table:: :header: "Parameter", "Default value", "Notes" - "ubm_num_iterations", 4, "Number of iterations for training UBM" - "ubm_num_gselect", 30, "Number of Gaussian-selection indices to use while training" - "ubm_num_frames", 500000, "Number of frames to keep in memory for initialization" - "ubm_num_gaussians", 256, "" - "ubm_num_iterations_init", 20, "Number of iteration to use when initializing UBM" - "ubm_initial_gaussian_proportion", 0.5, "Start with half the target number of Gaussians" - "ubm_min_gaussian_weight", 0.0001, "" - "ubm_remove_low_count_gaussians", True, "" "ivector_dimension", 128, "Dimension of extracted ivectors" "num_iterations", 10, "Number of training iterations" "num_gselect", 20, "Gaussian-selection using diagonal model: number of Gaussians to select" "posterior_scale", 1.0, "Scale on posteriors to correct for inter-frame correlation" - "silence_weight", 0.0, "" + "silence_weight", 0.0, "Weight of silence in calculating posteriors for ivector extraction" "min_post", 0.025, "Minimum posterior to use (posteriors below this are pruned out)" - "num_samples_for_weights", 3, "" 
"gaussian_min_count", 100, "" "subsample", 5, "Speeds up training (samples every Xth frame)" - "max_count", 100, "" - "apply_cmn", True, "Flag for whether to apply CMVN to input features" - + "max_count", 100, "The use of this option can make iVectors more consistent for different lengths of utterance, by scaling up the prior term when the data-count exceeds this value. The data-count is after posterior-scaling, so assuming the posterior-scale is 0.1, max_count=100 starts having effect after 1000 frames, or 10 seconds of data." + "uses_cmvn", True, "Flag for whether to apply CMVN to input features" .. _default_ivector_training_config: Default training config file ---------------------------- +The below configuration file shows the equivalent of the current 2.0 training regime, mostly as an example of what configuration options are available and how they progress through the overall training. + .. code-block:: yaml features: @@ -44,9 +62,15 @@ Default training config file frame_shift: 10 training: + - dubm: + num_iterations: 4 + num_gselect: 30 + num_gaussians: 256 + num_iterations_init: 20 - ivector: + ivector_dimension: 128 num_iterations: 10 - gaussian_min_count: 2 + gaussian_min_count: 100 silence_weight: 0.0 posterior_scale: 0.1 max_count: 100 diff --git a/docs/source/user_guide/configuration/lm.rst b/docs/source/user_guide/configuration/lm.rst index c27b8c0d..86f05114 100644 --- a/docs/source/user_guide/configuration/lm.rst +++ b/docs/source/user_guide/configuration/lm.rst @@ -1,21 +1,18 @@ -.. _lm_config: +.. _configuration_language_modeling: -**************************** -Language model configuration -**************************** +******************************* +Language model training options +******************************* -.. _train_lm_config: +See also the :ref:`configuration_dictionary` for the options that control how text is normalized and parsed. -Language model configuration -============================ .. 
csv-table:: :header: "Parameter", "Default value", "Notes" "order", 3, "Order of language model" "method", kneser_ney, "Method for smoothing" - "prune", false, "Flag for whether to output pruned models as well" "prune_thresh_small", 0.0000003, "Threshold for pruning a small model, only used if ``prune`` is true" "prune_thresh_medium", 0.0000001, "Threshold for pruning a medium model, only used if ``prune`` is true" @@ -26,6 +23,5 @@ Default language model config order: 3 method: kneser_ney - prune: false prune_thresh_small: 0.0000003 prune_thresh_medium: 0.0000001 diff --git a/docs/source/user_guide/configuration/segment.rst b/docs/source/user_guide/configuration/segment.rst index da73c16d..c211bfb3 100644 --- a/docs/source/user_guide/configuration/segment.rst +++ b/docs/source/user_guide/configuration/segment.rst @@ -1,9 +1,9 @@ -.. _configuration_segments: +.. _configuration_segmentation: -***************************** -Create segments configuration -***************************** +******************** +Segmentation options +******************** .. csv-table:: @@ -16,8 +16,8 @@ Create segments configuration .. _default_segment_config: -Default training config file ----------------------------- +Default segmentation config file +-------------------------------- .. code-block:: yaml diff --git a/docs/source/user_guide/configuration/transcription.rst b/docs/source/user_guide/configuration/transcription.rst index 2a0d3b5e..0b4d00fc 100644 --- a/docs/source/user_guide/configuration/transcription.rst +++ b/docs/source/user_guide/configuration/transcription.rst @@ -1,9 +1,9 @@ .. _transcribe_config: -************************* -Transcriber configuration -************************* +********************* +Transcription options +********************* .. 
csv-table:: :header: "Parameter", "Default value", "Notes" @@ -13,9 +13,9 @@ Transcriber configuration "lattice_beam", 6, "Beam width for decoding lattices" "acoustic_scale", 0.083333, "Multiplier to scale acoustic costs" "silence_weight", 0.01, "Weight on silence in fMLLR estimation" - "fmllr", true, "Flag for whether to perform speaker adaptation" - "first_beam", 10.0, "Beam for decoding in initial speaker-independent pass, only used if ``fmllr`` is true" - "first_max_active", 2000, "Max active for decoding in initial speaker-independent pass, only used if ``fmllr`` is true" + "uses_speaker_adaptation", true, "Flag for whether to perform speaker adaptation" + "first_beam", 10.0, "Beam for decoding in initial speaker-independent pass, only used if ``uses_speaker_adaptation`` is true" + "first_max_active", 2000, "Max active for decoding in initial speaker-independent pass, only used if ``uses_speaker_adaptation`` is true" "fmllr_update_type", "full", "Type of fMLLR estimation" Default transcriber config diff --git a/docs/source/user_guide/data_validation.rst b/docs/source/user_guide/data_validation.rst index 5d8abb3c..6827428f 100644 --- a/docs/source/user_guide/data_validation.rst +++ b/docs/source/user_guide/data_validation.rst @@ -36,3 +36,4 @@ Command reference .. autoprogram:: montreal_forced_aligner.command_line.mfa:parser :prog: mfa :start_command: validate + :groups: diff --git a/docs/source/user_guide/formats/corpus_structure.rst b/docs/source/user_guide/formats/corpus_structure.rst index 41ad926e..b8abeb00 100644 --- a/docs/source/user_guide/formats/corpus_structure.rst +++ b/docs/source/user_guide/formats/corpus_structure.rst @@ -141,4 +141,4 @@ MFA will automatically convert higher bit depths via the :code:`sox` conda packa Duration ======== -In general, audio segments (sound files for Prosodylab-aligner format or intervals for the TextGrid format) should be less than 30 seconds for best performance (the shorter the faster). 
We recommend using breaks like breaths or silent pauses (i.e., not associated with a stop closure) to separate the audio segments. For longer segments, setting the beam and retry beam higher than their defaults will allow them to be aligned. The default beam/retry beam is very conservative 10/40, so something like 400/1000 will allow for much longer sequences to be aligned. Though also note that the higher the beam value, the slower alignment will be as well. See :ref:`configuration_alignment` for more details. +In general, audio segments (sound files for Prosodylab-aligner format or intervals for the TextGrid format) should be less than 30 seconds for best performance (the shorter the faster). We recommend using breaks like breaths or silent pauses (i.e., not associated with a stop closure) to separate the audio segments. For longer segments, setting the beam and retry beam higher than their defaults will allow them to be aligned. The default beam/retry beam is very conservative 10/40, so something like 400/1000 will allow for much longer sequences to be aligned. Though also note that the higher the beam value, the slower alignment will be as well. See :ref:`configuration_global` for more details. diff --git a/docs/source/user_guide/formats/dictionary.rst b/docs/source/user_guide/formats/dictionary.rst index b7648724..05dbb0c7 100644 --- a/docs/source/user_guide/formats/dictionary.rst +++ b/docs/source/user_guide/formats/dictionary.rst @@ -76,7 +76,7 @@ With a pronunciation of: S E T E A N S E The key point to note is that the pronunciation of the clitic ``c'`` is ``S`` -and the pronunciation of the letter ``c`` in French is ``S A``. +and the pronunciation of the letter ``c`` in French is ``S E``. 
The algorithm will try to associate the clitic marker with either the element before (as for French clitics) or the element after (as for English clitics diff --git a/docs/source/user_guide/glossary.rst b/docs/source/user_guide/glossary.rst new file mode 100644 index 00000000..93b8899f --- /dev/null +++ b/docs/source/user_guide/glossary.rst @@ -0,0 +1,39 @@ + +Glossary +======== + +.. glossary:: + :sorted: + + Acoustic model + Acoustic models + GMM-HMM + Acoustic models calculate how likely a phone is given the acoustic features (and previous and following states). The architecture used in MFA is Gaussian mixture models with hidden markov models (GMM-HMM). The GMM component models the distributions of acoustic features per phone (well, really many distributions that map to phones in a many-to-many mapping), and then the HMM component tracks the transition probabilities between states. State of the art approaches to acoustic modeling used deep neural networks, either in a hybrid DNN-HMM framework, or more recently, doing away with phone labels entirely to just model acoustics to words or subwords. + + Language model + Language models + Language models calculate how likely a string of words is to occur, given the data they were trained on. They are typically generated over large text corpora. The architecture used in MFA is that of an N-Gram model (typically trigram), with a window of N-1 previous words that predict the current word. State of the art methods are typically RNN or transformer based approaches. + + Pronunciation dictionary + Pronunciation dictionaries + Pronunciation dictionaries are used to map words to phones that are aligned. The phone set used in the dictionary must match that of the :term:`acoustic model` used, since the acoustic model will not be able to estimate probabilities for a phone label if it wasn't trained on it. :term:`G2P models` can be used to generate pronunciation dictionaries. 
+ + Grapheme-to-Phoneme + G2P model + G2P models + G2P models generate sequences of phones based on an orthographic representation. Typically, the more transparent the orthography, the better the pronunciations generated. The architecture used in MFA is that of a weighted Finite State Transducer (wFST), based on :xref:`pynini`. More state of the art approaches use DNNs in a sequence-to-sequence task to get better performance, either RNNs or transformers. + + TextGrid + TextGrids + File format that can be used to mark time aligned utterances, and is the output format for alignments in MFA. See :xref:`praat` for more details about TextGrids and their use in phonetics. + + MFCCs + :abbr:`Mel-frequency cepstrum coefficients (MFCCs)` are the industry standard for acoustic features. The process involves windowing the acoustic waveform, scaling the frequencies into the Mel space (an auditory representation that gives more weight to lower frequencies over higher frequencies), and then performing a :abbr:`discrete cosine transform (DCT)` on the values in each filter bank to get orthogonal coefficients. There was a trend around 2015-2018 to use acoustic features that were more raw (i.e., not transformed to the Mel space, or the waveform directly), but in general most recent state of the art systems still use MFCC features. + + Pronunciation probabilities + Pronunciation probabilities in dictionaries allow for certain spoken forms to be more likely, rather than just assigning equal weight to all pronunciation variants. + + Ivectors + Ivector extractor + Ivector extractors + Ivectors are generated based off acoustic features like MFCCs, but are trained alongside a universal background model to be a representation of a speaker. 
diff --git a/docs/source/user_guide/index.rst b/docs/source/user_guide/index.rst index 238096cf..040fc29c 100644 --- a/docs/source/user_guide/index.rst +++ b/docs/source/user_guide/index.rst @@ -127,3 +127,4 @@ We acknowledge funding from Social Sciences and Humanities Research Council (SSH workflows/index configuration/index models/index + glossary diff --git a/docs/source/user_guide/models/acoustic.rst b/docs/source/user_guide/models/acoustic.rst index f2e7d9e6..050262c1 100644 --- a/docs/source/user_guide/models/acoustic.rst +++ b/docs/source/user_guide/models/acoustic.rst @@ -12,7 +12,7 @@ Pretrained acoustic models ************************** -As part of using the Montreal Forced Aligner in our own research, we have trained acoustic models for a number of languages. +As part of using the Montreal Forced Aligner in our own research, we have trained :term:`acoustic models` for a number of languages. If you would like to use them, please download them below. Please note the dictionary that they were trained with to see more information about the phone set. When using these with a pronunciation dictionary, the phone sets must be compatible. If the orthography of the language is transparent, it is likely that we have a G2P model that can be used diff --git a/docs/source/user_guide/models/dictionary.rst b/docs/source/user_guide/models/dictionary.rst index 919b282f..cbb99cbe 100644 --- a/docs/source/user_guide/models/dictionary.rst +++ b/docs/source/user_guide/models/dictionary.rst @@ -13,7 +13,7 @@ Available pronunciation dictionaries ************************************ -Any of the following pronunciation dictionaries can be downloaded with the command :code:`mfa model download dictionary `. You +Any of the following :term:`pronunciation dictionaries` can be downloaded with the command :code:`mfa model download dictionary `. You can get a full list of the currently available dictionaries via :code:`mfa model download dictionary`. 
New dictionaries contributed by users will be periodically added. If you would like to contribute your dictionaries, please contact Michael McAuliffe at michael.e.mcauliffe@gmail.com. diff --git a/docs/source/user_guide/models/g2p.rst b/docs/source/user_guide/models/g2p.rst index 56187ebb..04e010f5 100644 --- a/docs/source/user_guide/models/g2p.rst +++ b/docs/source/user_guide/models/g2p.rst @@ -16,7 +16,7 @@ Pretrained G2P models ********************* -Included with MFA is a separate tool to generate a dictionary from a preexisting model. This should be used if you're +Included with MFA is a separate tool to generate a dictionary from a preexisting :term:`G2P model`. This should be used if you're aligning a dataset for which you have no pronunciation dictionary or the orthography is very transparent. We have pretrained models for several languages below. diff --git a/docs/source/user_guide/models/index.rst b/docs/source/user_guide/models/index.rst index 6ff48f68..0543a59d 100644 --- a/docs/source/user_guide/models/index.rst +++ b/docs/source/user_guide/models/index.rst @@ -17,7 +17,7 @@ for downloading these is :code:`mfa model download ` where ``model_t Command reference ----------------- -.. autoprogram:: montreal_forced_aligner.command_line.mfa:parser +.. autoprogram:: montreal_forced_aligner.command_line.mfa:create_parser() :prog: mfa :start_command: model diff --git a/docs/source/user_guide/models/lm.rst b/docs/source/user_guide/models/lm.rst index 4bfc5a57..e9ce760b 100644 --- a/docs/source/user_guide/models/lm.rst +++ b/docs/source/user_guide/models/lm.rst @@ -11,7 +11,7 @@ Pretrained language models ************************** -There are several places that contain pretrained language models that can be imported to MFA. +There are several places that contain pretrained :term:`language models` that can be imported to MFA. .. 
csv-table:: :header: "Source", "Language", "Link" diff --git a/docs/source/user_guide/workflows/adapt_acoustic_model.rst b/docs/source/user_guide/workflows/adapt_acoustic_model.rst new file mode 100644 index 00000000..ad7a751d --- /dev/null +++ b/docs/source/user_guide/workflows/adapt_acoustic_model.rst @@ -0,0 +1,25 @@ +.. _adapt_acoustic_model: + +Adapt acoustic model to new data ``(mfa adapt)`` +================================================ + +A recent 2.0 functionality for MFA is to adapt pretrained :term:`acoustic models` to a new dataset. MFA will first align the dataset using the pretrained model, and then update the acoustic model's GMM means with those generated by the data. See :kaldi_steps:`train_map` for the Kaldi script this functionality corresponds to. As part of the adaptation process, MFA can generate final alignments and export these files if an output directory is specified in the command. + + +Command reference +----------------- + +.. autoprogram:: montreal_forced_aligner.command_line.mfa:create_parser() + :prog: mfa + :start_command: adapt + +Configuration reference +----------------------- + +- :ref:`configuration_global` +- :ref:`configuration_adapting` + +API reference +------------- + +- :class:`~montreal_forced_aligner.alignment.AdaptingAligner` diff --git a/docs/source/user_guide/workflows/aligning/adapt_acoustic_model.rst b/docs/source/user_guide/workflows/aligning/adapt_acoustic_model.rst deleted file mode 100644 index d52e28e5..00000000 --- a/docs/source/user_guide/workflows/aligning/adapt_acoustic_model.rst +++ /dev/null @@ -1,15 +0,0 @@ -.. _adapt_acoustic_model: - -*********************************** -Adapting acoustic model to new data -*********************************** - -A recent 2.0 functionality for MFA is to adapt pretrained models to a new dataset. MFA will first align the dataset using the pretrained model, and then perform a couple of rounds of speaker-adaptation training. 
- - -Command reference ------------------ - -.. autoprogram:: montreal_forced_aligner.command_line.mfa:parser - :prog: mfa - :start_command: adapt diff --git a/docs/source/user_guide/workflows/aligning/index.rst b/docs/source/user_guide/workflows/aligning/index.rst deleted file mode 100644 index ecc53fbe..00000000 --- a/docs/source/user_guide/workflows/aligning/index.rst +++ /dev/null @@ -1,13 +0,0 @@ - -.. _aligning: - -********************* -Generating alignments -********************* - - -.. toctree:: - :maxdepth: 3 - - adapt_acoustic_model.rst - pretrained.rst diff --git a/docs/source/user_guide/workflows/aligning/pretrained.rst b/docs/source/user_guide/workflows/aligning/pretrained.rst deleted file mode 100644 index 36ba8e67..00000000 --- a/docs/source/user_guide/workflows/aligning/pretrained.rst +++ /dev/null @@ -1,14 +0,0 @@ - -.. _pretrained_alignment: - -************************************ -Align with pretrained acoustic model -************************************ - - -Command reference ------------------ - -.. autoprogram:: montreal_forced_aligner.command_line.mfa:parser - :prog: mfa - :start_command: align diff --git a/docs/source/user_guide/workflows/alignment.rst b/docs/source/user_guide/workflows/alignment.rst new file mode 100644 index 00000000..756f5f7c --- /dev/null +++ b/docs/source/user_guide/workflows/alignment.rst @@ -0,0 +1,24 @@ + +.. _pretrained_alignment: + +Align with an acoustic model ``(mfa align)`` +============================================ + +This is the primary workflow of MFA, where you can use pretrained :term:`acoustic models` to align your dataset. There are a number of :ref:`pretrained_acoustic_models` to use, but you can also adapt a pretrained model to your data (see :ref:`adapt_acoustic_model`) or train an acoustic model from scratch using your dataset (see :ref:`train_acoustic_model`). + +Command reference +----------------- + +.. 
autoprogram:: montreal_forced_aligner.command_line.mfa:create_parser() + :prog: mfa + :start_command: align + +Configuration reference +----------------------- + +- :ref:`configuration_global` + +API reference +------------- + +- :ref:`alignment_api` diff --git a/docs/source/user_guide/workflows/anchor.rst b/docs/source/user_guide/workflows/anchor.rst index d4bc2e0a..d2c7eb7d 100644 --- a/docs/source/user_guide/workflows/anchor.rst +++ b/docs/source/user_guide/workflows/anchor.rst @@ -3,23 +3,21 @@ .. _anchor: -**************** -Anchor annotator -**************** +Anchor annotator ``(mfa anchor)`` +================================= The Anchor Annotator is a GUI utility for MFA that allows for users to modify transcripts and add/change entries in the pronunciation dictionary to interactively fix out of vocabulary issues. .. attention:: - Anchor is under development and is currently pre-alpha. Use at your own risk and please use version control - or back up any critical data. + Anchor is under development and is currently pre-alpha. Use at your own risk and please use version control or back up any critical data. To use the annotator, first install the anchor subpackage: .. code-block:: - pip install montreal-forced-aligner[anchor] + conda install montreal-forced-aligner[anchor] This will install MFA if hasn't been along with all the packages that Anchor requires. Once installed, Anchor can be started with the following MFA subcommand: diff --git a/docs/source/user_guide/workflows/classify_speakers.rst b/docs/source/user_guide/workflows/classify_speakers.rst index 7a309687..ff170afa 100644 --- a/docs/source/user_guide/workflows/classify_speakers.rst +++ b/docs/source/user_guide/workflows/classify_speakers.rst @@ -1,15 +1,17 @@ .. 
_classify_speakers: -********************** -Speaker classification -********************** +Cluster speakers ``(mfa classify_speakers)`` +============================================ -The Montreal Forced Aligner can use trained ivector models (see :ref:`train_ivector` for more information about training -these models) to classify or cluster utterances according to speakers. +The Montreal Forced Aligner can use trained ivector models (see :ref:`train_ivector` for more information about training these models) to classify or cluster utterances according to speakers. + +.. warning:: + + This feature is not fully implemented, and is still under construction. Command reference ----------------- -.. autoprogram:: montreal_forced_aligner.command_line.mfa:parser +.. autoprogram:: montreal_forced_aligner.command_line.mfa:create_parser() :prog: mfa :start_command: classify_speakers diff --git a/docs/source/user_guide/workflows/corpus_creation.rst b/docs/source/user_guide/workflows/corpus_creation.rst index 3b6ca42f..652ff816 100644 --- a/docs/source/user_guide/workflows/corpus_creation.rst +++ b/docs/source/user_guide/workflows/corpus_creation.rst @@ -4,8 +4,7 @@ Corpus creation utilities ************************* -MFA now contains several command line utilities for helping to create corpora from scratch. The main workflow is as -follows: +MFA now contains several command line utilities for helping to create corpora from scratch. The main workflow is as follows: 1. If the corpus made up of long sound file that need segmenting, :ref:`create_segments` 2. If the corpus does not contain transcriptions, transcribe utterances using existing acoustic models, diff --git a/docs/source/user_guide/workflows/create_segments.rst b/docs/source/user_guide/workflows/create_segments.rst index 26c4dd19..db99dbe1 100644 --- a/docs/source/user_guide/workflows/create_segments.rst +++ b/docs/source/user_guide/workflows/create_segments.rst @@ -1,8 +1,7 @@ .. 
_create_segments: -*************** -Create segments -*************** +Create segments ``(mfa create_segments)`` +========================================= The Montreal Forced Aligner can use Voice Activity Detection (VAD) capabilities from Kaldi to generate segments from a longer sound file. @@ -10,13 +9,23 @@ a longer sound file. .. note:: The default configuration for VAD uses configuration values based on quiet speech. The algorithm is based on energy, - so if your recordings are more noisy, you may need to adjust the configuration. See :ref:`configuration_segments` + so if your recordings are more noisy, you may need to adjust the configuration. See :ref:`configuration_segmentation` for more information on changing these parameters. Command reference ----------------- -.. autoprogram:: montreal_forced_aligner.command_line.mfa:parser +.. autoprogram:: montreal_forced_aligner.command_line.mfa:create_parser() :prog: mfa :start_command: create_segments + +Configuration reference +----------------------- + +- :ref:`configuration_segmentation` + +API reference +------------- + +- :ref:`segmentation_api` diff --git a/docs/source/user_guide/workflows/g2p/dictionary_generating.rst b/docs/source/user_guide/workflows/dictionary_generating.rst similarity index 78% rename from docs/source/user_guide/workflows/g2p/dictionary_generating.rst rename to docs/source/user_guide/workflows/dictionary_generating.rst index 55629b14..faecd6db 100644 --- a/docs/source/user_guide/workflows/g2p/dictionary_generating.rst +++ b/docs/source/user_guide/workflows/dictionary_generating.rst @@ -2,9 +2,8 @@ .. _g2p_dictionary_generating: -*********************** -Generating a dictionary -*********************** +Generate a new pronunciation dictionary ``(mfa g2p)`` +===================================================== We have trained several G2P models that are available for download (:ref:`pretrained_g2p`). 
@@ -34,6 +33,17 @@ See :ref:`dict_generating_example` for an example of how to use G2P functionalit Command reference ----------------- -.. autoprogram:: montreal_forced_aligner.command_line.mfa:parser +.. autoprogram:: montreal_forced_aligner.command_line.mfa:create_parser() :prog: mfa :start_command: g2p + +Configuration reference +----------------------- + +- :ref:`configuration_g2p` +- :ref:`configuration_dictionary` + +API reference +------------- + +- :ref:`g2p_generate_api` diff --git a/docs/source/user_guide/workflows/g2p/index.rst b/docs/source/user_guide/workflows/g2p/index.rst deleted file mode 100644 index 6a41fcbc..00000000 --- a/docs/source/user_guide/workflows/g2p/index.rst +++ /dev/null @@ -1,20 +0,0 @@ -.. _g2p: - -************************* -Grapheme-to-Phoneme (G2P) -************************* - -There are many cases where a language's orthography is transparent, and creating an exhaustive list of all words in a corpus -is doable by rule rather than just listing. For these cases, we offer pretrained grapheme-to-phoneme (G2P) models, as -well as a way to train new G2P models. - -Currently, the way unknown symbols are handled is not perfect. If an unknown symbol is found, it is skipped and the pronunciation -for the rest of the orthography is generated. Please be careful when using this system for languages with logographic writing -systems such as Chinese or Japanese where unknown symbols are likely given the number of distinct characters, and be sure to -always check the resulting dictionary carefully before potentially propagating errors into the alignment. - -.. 
toctree:: - :maxdepth: 3 - - dictionary_generating.rst - model_training.rst diff --git a/docs/source/user_guide/workflows/g2p/model_training.rst b/docs/source/user_guide/workflows/g2p_train.rst similarity index 70% rename from docs/source/user_guide/workflows/g2p/model_training.rst rename to docs/source/user_guide/workflows/g2p_train.rst index d2852f33..49fbf367 100644 --- a/docs/source/user_guide/workflows/g2p/model_training.rst +++ b/docs/source/user_guide/workflows/g2p_train.rst @@ -3,11 +3,10 @@ .. _g2p_model_training: -************************ -Training a new G2P model -************************ +Train a new G2P model ``(mfa train_g2p)`` +========================================= -Another tool included with MFA allows you to train a G2P (Grapheme to Phoneme) model automatically from a given +Another tool included with MFA allows you to train a :term:`G2P model` from a given pronunciation dictionary. This type of model can be used for :ref:`g2p_dictionary_generating`. It requires a pronunciation dictionary with each line consisting of the orthographic transcription followed by the @@ -25,6 +24,19 @@ See :ref:`g2p_model_training_example` for an example of how to train a G2P model Command reference ----------------- -.. autoprogram:: montreal_forced_aligner.command_line.mfa:parser +.. autoprogram:: montreal_forced_aligner.command_line.mfa:create_parser() :prog: mfa :start_command: train_g2p + +Configuration reference +----------------------- + +- :ref:`configuration_dictionary` +- :ref:`configuration_g2p` + + - :ref:`train_g2p_config` + +API reference +------------- + +- :ref:`g2p_modeling_api` diff --git a/docs/source/user_guide/workflows/index.rst b/docs/source/user_guide/workflows/index.rst index 01216529..fa2d2475 100644 --- a/docs/source/user_guide/workflows/index.rst +++ b/docs/source/user_guide/workflows/index.rst @@ -1,10 +1,25 @@ + +.. 
_workflows_index: + Workflows available =================== +The primary workflow in MFA is forced alignment, where text is aligned to speech along with phones derived from a pronunciation dictionary and an acoustic model. There are, however, other workflows for transcribing speech using speech-to-text functionality in Kaldi, pronunciation dictionary creation using Pynini, and some basic corpus creation utilities like VAD-based segmentation. Additionally, acoustic models, G2P models, and language models can be trained from your own data (and then used in alignment and other workflows). + +.. warning:: + + Speech-to-text functionality is pretty basic, and the model architecture used in MFA is older GMM-HMM and NGram models, so using something like :xref:`coqui` or Kaldi's ``nnet`` functionality will likely yield better quality transcriptions. + +.. hint:: + + See :ref:`pretrained_models` for details about commands to inspect, download, and save various pretrained MFA models. + .. toctree:: :maxdepth: 2 - aligning/index + alignment + adapt_acoustic_model train_acoustic_model - g2p/index + dictionary_generating + g2p_train corpus_creation diff --git a/docs/source/user_guide/workflows/train_acoustic_model.rst b/docs/source/user_guide/workflows/train_acoustic_model.rst index f04e82da..1c1afa5e 100644 --- a/docs/source/user_guide/workflows/train_acoustic_model.rst +++ b/docs/source/user_guide/workflows/train_acoustic_model.rst @@ -1,14 +1,27 @@ .. _train_acoustic_model: -***************************** -Training a new acoustic model -***************************** +Train a new acoustic model ``(mfa train)`` +========================================== + +You can train new :term:`acoustic models` from scratch using MFA, and export the final alignments as :term:`TextGrids` at the end. You don't need a ton of data to generate decent alignments (see `the blog post comparing alignments trained on various corpus sizes `_). 
At the end of the day, it comes down to trial and error, so I would recommend trying different workflows of pretrained models vs training your own or adapting a model to your data to see what performs best. Command reference ----------------- -.. autoprogram:: montreal_forced_aligner.command_line.mfa:parser +.. autoprogram:: montreal_forced_aligner.command_line.mfa:create_parser() :prog: mfa :start_command: train + +Configuration reference +----------------------- + +- :ref:`configuration_acoustic_modeling` + +API reference +------------- + +- :ref:`acoustic_modeling_api` + + - :ref:`acoustic_model_training_api` diff --git a/docs/source/user_guide/workflows/train_ivector.rst b/docs/source/user_guide/workflows/train_ivector.rst index e15a017e..fa0dee09 100644 --- a/docs/source/user_guide/workflows/train_ivector.rst +++ b/docs/source/user_guide/workflows/train_ivector.rst @@ -1,15 +1,27 @@ .. _train_ivector: -***************************** -Training an ivector extractor -***************************** +Train an ivector extractor ``(mfa train_ivector)`` +================================================== -The Montreal Forced Aligner can train ivector extractors using an acoustic model for generating alignments. As part -of this training process, a classifier is built in that can be used as part of :ref:`classify_speakers`. +The Montreal Forced Aligner can train :term:`ivector extractors` using an acoustic model for generating alignments. As part of this training process, a classifier is built in that can be used as part of :ref:`classify_speakers`. + +.. warning:: + + This feature is not fully implemented, and is still under construction. Command reference ----------------- -.. autoprogram:: montreal_forced_aligner.command_line.mfa:parser +.. 
autoprogram:: montreal_forced_aligner.command_line.mfa:create_parser() :prog: mfa :start_command: train_ivector + +Configuration reference +----------------------- + +- :ref:`configuration_ivector` + +API reference +------------- + +- :ref:`ivector_api` diff --git a/docs/source/user_guide/workflows/training_dictionary.rst b/docs/source/user_guide/workflows/training_dictionary.rst index fb103cde..5f191114 100644 --- a/docs/source/user_guide/workflows/training_dictionary.rst +++ b/docs/source/user_guide/workflows/training_dictionary.rst @@ -1,10 +1,9 @@ .. _training_dictionary: -************************************ -Modeling pronunciation probabilities -************************************ +Add probabilities to a dictionary ``(mfa train_dictionary)`` +============================================================ -MFA includes a utility command for training pronunciation probabilities of a dictionary given a corpus for alignment. +MFA includes a utility command for training :term:`pronunciation probabilities` of a dictionary given a corpus for alignment. The resulting dictionary can then be used as a dictionary for alignment or transcription. @@ -12,6 +11,6 @@ The resulting dictionary can then be used as a dictionary for alignment or trans Command reference ----------------- -.. autoprogram:: montreal_forced_aligner.command_line.mfa:parser +.. autoprogram:: montreal_forced_aligner.command_line.mfa:create_parser() :prog: mfa :start_command: train_dictionary diff --git a/docs/source/user_guide/workflows/training_lm.rst b/docs/source/user_guide/workflows/training_lm.rst index 2d9b42fa..22d7959a 100644 --- a/docs/source/user_guide/workflows/training_lm.rst +++ b/docs/source/user_guide/workflows/training_lm.rst @@ -1,10 +1,9 @@ .. 
_training_lm: -************************ -Training language models -************************ +Train a new language model ``(mfa train_lm)`` +============================================== -MFA has a utility function for training ARPA-format ngram language models, as well as merging with a pre-existing model. +MFA has a utility function for training ARPA-format ngram :term:`language models`, as well as merging with a pre-existing model. .. warning:: @@ -14,6 +13,16 @@ MFA has a utility function for training ARPA-format ngram language models, as we Command reference ----------------- -.. autoprogram:: montreal_forced_aligner.command_line.mfa:parser +.. autoprogram:: montreal_forced_aligner.command_line.mfa:create_parser() :prog: mfa :start_command: train_lm + +Configuration reference +----------------------- + +- :ref:`configuration_language_modeling` + +API reference +------------- + +- :ref:`language_modeling_api` diff --git a/docs/source/user_guide/workflows/transcribing.rst b/docs/source/user_guide/workflows/transcribing.rst index 9b40322c..60c1c506 100644 --- a/docs/source/user_guide/workflows/transcribing.rst +++ b/docs/source/user_guide/workflows/transcribing.rst @@ -1,18 +1,26 @@ -.. _`Coqui`: https://coqui.ai/ .. _transcribing: -********************************* -Transcribe audio (Speech-to-text) -********************************* +Transcribe audio files ``(mfa transcribe)`` +=========================================== .. warning:: - The technology that MFA uses is several years out of date, and as such if you have other options available such as :xref:`coqui` or other production systems for speech-to-text, we recommend using those. The transcription capabilities are more here for completeness. + The technology that MFA uses is several years out of date, and as such if you have other options available such as :xref:`coqui` or other production systems for :abbr:`STT (Speech to Text)`, we recommend using those. 
The transcription capabilities are more here for completeness. Command reference ----------------- -.. autoprogram:: montreal_forced_aligner.command_line.mfa:parser +.. autoprogram:: montreal_forced_aligner.command_line.mfa:create_parser() :prog: mfa :start_command: transcribe + +Configuration reference +----------------------- + +- :ref:`transcribe_config` + +API reference +------------- + +- :ref:`transcription_api` diff --git a/environment.yml b/environment.yml index 38e98321..03fb37d5 100644 --- a/environment.yml +++ b/environment.yml @@ -2,7 +2,7 @@ name: mfa channels: - conda-forge dependencies: - - python>=3.8 + - python>=3.9 - numpy - librosa - tqdm @@ -15,6 +15,4 @@ dependencies: - baumwelch - ngram - pynini - - pip - - pip: - - praatio >= 5.0 + - praatio diff --git a/environment_win.yml b/environment_win.yml index a1dd9a47..1e827581 100644 --- a/environment_win.yml +++ b/environment_win.yml @@ -2,7 +2,7 @@ name: montreal-forced-aligner channels: - conda-forge dependencies: - - python>=3.8 + - python>=3.9 - numpy - librosa - tqdm @@ -11,6 +11,4 @@ dependencies: - pyyaml - kaldi - sox - - pip - - pip: - - praatio>=5.0 + - praatio diff --git a/montreal_forced_aligner/__init__.py b/montreal_forced_aligner/__init__.py index 2d84b423..a627c060 100644 --- a/montreal_forced_aligner/__init__.py +++ b/montreal_forced_aligner/__init__.py @@ -1,31 +1,36 @@ """Montreal Forced Aligner is a package for aligning speech corpora through the use of acoustic models and dictionaries using Kaldi functionality.""" -import montreal_forced_aligner.aligner as aligner # noqa -import montreal_forced_aligner.command_line as command_line # noqa -import montreal_forced_aligner.config as config # noqa -import montreal_forced_aligner.corpus as corpus # noqa -import montreal_forced_aligner.dictionary as dictionary # noqa -import montreal_forced_aligner.exceptions as exceptions # noqa -import montreal_forced_aligner.g2p as g2p # noqa -import montreal_forced_aligner.helper as helper # noqa 
-import montreal_forced_aligner.models as models # noqa -import montreal_forced_aligner.multiprocessing as multiprocessing # noqa -import montreal_forced_aligner.textgrid as textgrid # noqa -import montreal_forced_aligner.utils as utils # noqa +import montreal_forced_aligner.acoustic_modeling as acoustic_modeling +import montreal_forced_aligner.alignment as alignment +import montreal_forced_aligner.command_line as command_line +import montreal_forced_aligner.corpus as corpus +import montreal_forced_aligner.dictionary as dictionary +import montreal_forced_aligner.exceptions as exceptions +import montreal_forced_aligner.g2p as g2p +import montreal_forced_aligner.helper as helper +import montreal_forced_aligner.ivector as ivector +import montreal_forced_aligner.language_modeling as language_modeling +import montreal_forced_aligner.models as models +import montreal_forced_aligner.textgrid as textgrid +import montreal_forced_aligner.transcription as transcription +import montreal_forced_aligner.utils as utils __all__ = [ "abc", "data", - "aligner", + "acoustic_modeling", + "alignment", "command_line", "config", "corpus", "dictionary", "exceptions", "g2p", + "ivector", + "language_modeling", "helper", "models", - "multiprocessing", + "transcription", "textgrid", "utils", ] diff --git a/montreal_forced_aligner/__main__.py b/montreal_forced_aligner/__main__.py new file mode 100644 index 00000000..0481abd2 --- /dev/null +++ b/montreal_forced_aligner/__main__.py @@ -0,0 +1,3 @@ +from .command_line.mfa import main + +main() diff --git a/montreal_forced_aligner/abc.py b/montreal_forced_aligner/abc.py index 88161c84..db3474a6 100644 --- a/montreal_forced_aligner/abc.py +++ b/montreal_forced_aligner/abc.py @@ -5,283 +5,729 @@ from __future__ import annotations +import logging +import os +import shutil +import sys +import time from abc import ABC, ABCMeta, abstractmethod -from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union +from typing import 
TYPE_CHECKING, Any, Optional, Type, Union, get_type_hints -if TYPE_CHECKING: - from .config.align_config import AlignConfig - from .config.dictionary_config import DictionaryConfig - from .config.transcribe_config import TranscribeConfig - from .corpus.base import Corpus - from .dictionary.multispeaker import MultispeakerDictionary - from .models import AcousticModel, DictionaryModel, LanguageModel +import yaml +if TYPE_CHECKING: + from argparse import Namespace __all__ = [ "MfaModel", "MfaWorker", - "Dictionary", + "TopLevelMfaWorker", "MetaDict", - "AcousticModelWorker", - "IvectorExtractor", - "Trainer", - "Transcriber", - "Aligner", + "MappingType", + "CtmErrorDict", + "FileExporterMixin", + "ModelExporterMixin", + "TemporaryDirectoryMixin", + "AdapterMixin", + "TrainerMixin", "DictionaryEntryType", "ReversedMappingType", "Labels", + "WordsType", + "OneToOneMappingType", + "OneToManyMappingType", + "CorpusMappingType", + "ScpType", ] # Configuration types -MetaDict = Dict[str, Any] -Labels = List[Any] -CtmErrorDict = Dict[Tuple[str, int], str] +MetaDict = dict[str, Any] +Labels: list[Any] +CtmErrorDict: dict[tuple[str, int], str] # Dictionary types -DictionaryEntryType = List[Dict[str, Union[Tuple[str], float, None, int]]] -ReversedMappingType = Dict[int, str] -WordsType = Dict[str, DictionaryEntryType] -MappingType = Dict[str, int] -MultiSpeakerMappingType = Dict[str, str] -IpaType = Optional[List[str]] -PunctuationType = Optional[str] +DictionaryEntryType: list[dict[str, Union[tuple[str], float, None, int]]] +ReversedMappingType: dict[int, str] +WordsType: dict[str, DictionaryEntryType] +MappingType: dict[str, int] # Corpus types -SegmentsType = Dict[str, Dict[str, Union[str, float, int]]] -OneToOneMappingType = Dict[str, str] -OneToManyMappingType = Dict[str, List[str]] +OneToOneMappingType: dict[str, str] +OneToManyMappingType: dict[str, list[str]] -CorpusMappingType = Union[OneToOneMappingType, OneToManyMappingType] -ScpType = Union[List[Tuple[str, str]], 
List[Tuple[str, List[Any]]]] -CorpusGroupedOneToOne = List[List[Tuple[str, str]]] -CorpusGroupedOneToMany = List[List[Tuple[str, List[Any]]]] -CorpusGroupedType = Union[CorpusGroupedOneToMany, CorpusGroupedOneToOne] +CorpusMappingType: Union[OneToOneMappingType, OneToManyMappingType] +ScpType: Union[list[tuple[str, str]], list[tuple[str, list[Any]]]] -class MfaWorker(metaclass=ABCMeta): - """Abstract class for MFA workers""" +class TemporaryDirectoryMixin(metaclass=ABCMeta): + """ + Abstract mixin class for MFA temporary directories + + Parameters + ---------- + temporary_directory: str, optional + Path to store temporary files + """ + + def __init__( + self, + temporary_directory: str = None, + **kwargs, + ): + super().__init__(**kwargs) + if not temporary_directory: + from .config import get_temporary_directory - def __init__(self, corpus: Corpus): - self.corpus = corpus + temporary_directory = get_temporary_directory() + self.temporary_directory = temporary_directory @property @abstractmethod - def working_directory(self) -> str: - """Current directory""" + def identifier(self) -> str: + """Identifier to use in creating the temporary directory""" ... @property - def data_directory(self) -> str: - """Corpus data directory""" - return self._data_directory - - @data_directory.setter - def data_directory(self, val: str) -> None: - self._data_directory = val + @abstractmethod + def data_source_identifier(self) -> str: + """Identifier for the data source (generally the corpus being used)""" + ... @property - def uses_voiced(self) -> bool: - """Flag for using voiced features""" - return self._uses_voiced + @abstractmethod + def output_directory(self) -> str: + """Root temporary directory""" + ... 
- @uses_voiced.setter - def uses_voiced(self, val: bool) -> None: - self._uses_voiced = val + @property + def corpus_output_directory(self) -> str: + """Temporary directory containing all corpus information""" + return os.path.join(self.output_directory, f"{self.data_source_identifier}") @property - def uses_cmvn(self) -> bool: - """Flag for using CMVN""" - return self._uses_cmvn + def dictionary_output_directory(self) -> str: + """Temporary directory containing all dictionary information""" + return os.path.join(self.output_directory, "dictionary") - @uses_cmvn.setter - def uses_cmvn(self, val: bool) -> None: - self._uses_cmvn = val - @property - def uses_splices(self) -> bool: - """Flag for using spliced features""" - return self._uses_splices +class MfaWorker(metaclass=ABCMeta): + """ + Abstract class for MFA workers - @uses_splices.setter - def uses_splices(self, val: bool) -> None: - self._uses_splices = val + Parameters + ---------- + use_mp: bool + Flag to run in multiprocessing mode, defaults to True + debug: bool + Flag to run in debug mode, defaults to False + verbose: bool + Flag to run in verbose mode, defaults to False - @property - def speaker_independent(self) -> bool: - """Flag for speaker independent features""" - return self._speaker_independent + Attributes + ---------- + dirty: bool + Flag for whether an error was encountered in processing + """ - @speaker_independent.setter - def speaker_independent(self, val: bool) -> None: - self._speaker_independent = val + def __init__( + self, + use_mp: bool = True, + debug: bool = False, + verbose: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + self.debug = debug + self.verbose = verbose + self.use_mp = use_mp + self.dirty = False + + @abstractmethod + def log_debug(self, message: str) -> None: + """Abstract method for logging debug messages""" + ... + + @abstractmethod + def log_info(self, message: str) -> None: + """Abstract method for logging info messages""" + ... 
+ + @abstractmethod + def log_warning(self, message: str) -> None: + """Abstract method for logging warning messages""" + ... + + @abstractmethod + def log_error(self, message: str) -> None: + """Abstract method for logging error messages""" + ... + + @classmethod + def extract_relevant_parameters(cls, config: MetaDict) -> MetaDict: + """ + Filter a configuration dictionary to just the relevant parameters for the current worker + + Parameters + ---------- + config: dict[str, Any] + Configuration dictionary + + Returns + ------- + dict[str, Any] + Filtered configuration dictionary + """ + return {k: v for k, v in config.items() if k in cls.get_configuration_parameters()} + + @classmethod + def get_configuration_parameters(cls) -> dict[str, Type]: + """ + Get the types of parameters available to be configured + + Returns + ------- + dict[str, Type] + Dictionary of parameter names and their types + """ + configuration_params = {} + for t, ty in get_type_hints(cls.__init__).items(): + configuration_params[t] = ty + try: + if ty.__origin__ == Union: + configuration_params[t] = ty.__args__[0] + except AttributeError: + pass + + for c in cls.mro(): + try: + for t, ty in get_type_hints(c.__init__).items(): + configuration_params[t] = ty + try: + if ty.__origin__ == Union: + configuration_params[t] = ty.__args__[0] + except AttributeError: + pass + except AttributeError: + pass + return configuration_params + + @property + def configuration(self) -> MetaDict: + """Configuration parameters""" + return { + "debug": self.debug, + "verbose": self.verbose, + "use_mp": self.use_mp, + "dirty": self.dirty, + } @property @abstractmethod - def working_log_directory(self) -> str: - """Current log directory""" + def working_directory(self) -> str: + """Current working directory""" ... 
@property - def use_mp(self) -> bool: - """Flag for using multiprocessing""" - return self._use_mp + def working_log_directory(self) -> str: + """Current working log directory""" + return os.path.join(self.working_directory, "log") - @use_mp.setter - def use_mp(self, val: bool) -> None: - self._use_mp = val + @property + @abstractmethod + def data_directory(self) -> str: + """Data directory""" + ... -class AcousticModelWorker(MfaWorker): +class TopLevelMfaWorker(MfaWorker, TemporaryDirectoryMixin, metaclass=ABCMeta): """ - Abstract class for MFA classes that use acoustic models + Abstract mixin for top-level workers in MFA. This class holds properties about the larger workflow run. Parameters ---------- - dictionary: MultispeakerDictionary - Dictionary for the worker docstring + num_jobs: int + Number of jobs and processes to uses + clean: bool + Flag for whether to remove any old files in the work directory """ - def __init__(self, corpus: Corpus, dictionary: MultispeakerDictionary): - super().__init__(corpus) - self.dictionary: MultispeakerDictionary = dictionary + def __init__( + self, + num_jobs: int = 3, + clean: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + self.num_jobs = num_jobs + self.clean = clean + self.initialized = False + self.start_time = time.time() + self.setup_logger() + + def __del__(self): + """Ensure that loggers are cleaned up on delete""" + handlers = self.logger.handlers[:] + for handler in handlers: + handler.close() + self.logger.removeHandler(handler) + @abstractmethod + def setup(self) -> None: + """Abstract method for setting up a top-level worker""" + ... 
-class Trainer(AcousticModelWorker): - """ - Abstract class for MFA trainers + @property + def working_directory(self) -> str: + """Alias for a folder that contains worker information, separate from the data directory""" + return self.workflow_directory - Attributes - ---------- - iteration: int - Current iteration - """ + @classmethod + def parse_args(cls, args: Optional[Namespace], unknown_args: Optional[list[str]]) -> MetaDict: + """ + Class method for parsing configuration parameters from command line arguments + + Parameters + ---------- + args: :class:`~argparse.Namespace` + Arguments parsed by argparse + unknown_args: list[str] + Optional list of arguments that were not parsed by argparse + + Returns + ------- + dict[str, Any] + Dictionary of specified configuration parameters + """ + param_types = cls.get_configuration_parameters() + params = {} + unknown_dict = {} + if unknown_args: + for i, a in enumerate(unknown_args): + if not a.startswith("--"): + continue + name = a.replace("--", "") + if name not in param_types: + continue + if i == len(unknown_args) - 1 or unknown_args[i + 1].startswith("--"): + val = True + else: + val = unknown_args[i + 1] + unknown_dict[name] = val + for name, param_type in param_types.items(): + if name.endswith("_directory") or name.endswith("_path"): + continue + if args is not None and hasattr(args, name): + params[name] = param_type(getattr(args, name)) + elif name in unknown_dict: + params[name] = param_type(unknown_dict[name]) + if param_type == bool: + if unknown_dict[name].lower() == "false": + params[name] = False + return params - def __init__(self, corpus: Corpus, dictionary: MultispeakerDictionary): - super(Trainer, self).__init__(corpus, dictionary) - self.iteration = 0 + @classmethod + def parse_parameters( + cls, + config_path: Optional[str] = None, + args: Optional[Namespace] = None, + unknown_args: Optional[list[str]] = None, + ) -> MetaDict: + """ + Parse configuration parameters from a config file and command 
line arguments + + Parameters + ---------- + config_path: str, optional + Path to yaml configuration file + args: :class:`~argparse.Namespace`, optional + Arguments parsed by argparse + unknown_args: list[str], optional + List of unknown arguments from argparse + + Returns + ------- + dict[str, Any] + Dictionary of specified configuration parameters + """ + global_params = {} + if config_path and os.path.exists(config_path): + with open(config_path, "r", encoding="utf8") as f: + data = yaml.load(f, Loader=yaml.SafeLoader) + for k, v in data.items(): + global_params[k] = v + global_params.update(cls.parse_args(args, unknown_args)) + return global_params @property @abstractmethod - def meta(self) -> MetaDict: - """Training configuration parameters""" + def workflow_identifier(self) -> str: + """Identifier of the worker's workflow""" ... - @abstractmethod - def train(self) -> None: - """Perform training""" - ... + @property + def worker_config_path(self): + """Path to worker's configuration in the working directory""" + return os.path.join(self.output_directory, f"{self.workflow_identifier}.yaml") + + def cleanup(self) -> None: + """ + Clean up loggers and output final message for top-level workers + """ + try: + if self.dirty: + self.logger.error("There was an error in the run, please see the log.") + else: + self.logger.info(f"Done! 
Everything took {time.time() - self.start_time} seconds") + handlers = self.logger.handlers[:] + for handler in handlers: + handler.close() + self.logger.removeHandler(handler) + self.save_worker_config() + except (NameError, ValueError): # already cleaned up + pass + + def save_worker_config(self): + """Export worker configuration to its working directory""" + with open(self.worker_config_path, "w") as f: + yaml.dump(self.configuration, f) + + def _validate_previous_configuration(self, conf: MetaDict) -> bool: + """ + Validate the current configuration against a previous configuration + + Parameters + ---------- + conf: dict[str, Any] + Previous run's configuration + + Returns + ------- + bool + Flag for whether the current run is compatible with the previous one + """ + from montreal_forced_aligner.utils import get_mfa_version + + clean = True + current_version = get_mfa_version() + if conf["dirty"]: + self.logger.debug("Previous run ended in an error (maybe ctrl-c?)") + clean = False + if "type" in conf: + command = conf["type"] + elif "command" in conf: + command = conf["command"] + else: + command = self.workflow_identifier + if command != self.workflow_identifier: + self.logger.debug( + f"Previous run was a different subcommand than {self.workflow_identifier} (was {command})" + ) + clean = False + if conf.get("version", current_version) != current_version: + self.logger.debug( + f"Previous run was on {conf['version']} version (new run: {current_version})" + ) + clean = False + for key in [ + "corpus_directory", + "dictionary_path", + "acoustic_model_path", + "g2p_model_path", + "language_model_path", + ]: + if conf.get(key, None) != getattr(self, key, None): + self.logger.debug( + f"Previous run used a different {key.replace('_', ' ')} than {getattr(self, key, None)} (was {conf.get(key, None)})" + ) + clean = False + return clean + + def check_previous_run(self) -> bool: + """ + Check whether a previous run has any conflicting settings with the current run. 
+ + Returns + ------- + bool + Flag for whether the current run is compatible with the previous one + """ + if not os.path.exists(self.worker_config_path): + return True + with open(self.worker_config_path, "r") as f: + conf = yaml.load(f, Loader=yaml.SafeLoader) + clean = self._validate_previous_configuration(conf) + if not clean: + self.logger.warning( + "The previous run had a different configuration than the current, which may cause issues." + " Please see the log for details or use --clean flag if issues are encountered." + ) + return clean + @property + def identifier(self) -> str: + """Combined identifier of the data source and workflow""" + return f"{self.data_source_identifier}_{self.workflow_identifier}" -class Aligner(AcousticModelWorker): - """Abstract class for MFA aligners""" + @property + def output_directory(self) -> str: + """Root temporary directory to store all of this worker's files""" + return os.path.join(self.temporary_directory, self.identifier) - def __init__( - self, corpus: Corpus, dictionary: MultispeakerDictionary, align_config: AlignConfig - ): - super().__init__(corpus, dictionary) - self.align_config = align_config + @property + def workflow_directory(self) -> str: + """Temporary directory to save work specific to the worker (i.e., not data)""" + return os.path.join(self.output_directory, self.workflow_identifier) - @abstractmethod - def align(self, subset: Optional[int] = None) -> None: - """Perform alignment""" - ... 
+ @property + def log_file(self): + """Path to the worker's log file""" + return os.path.join(self.output_directory, f"{self.workflow_identifier}.log") + + def setup_logger(self): + """ + Construct a logger for a command line run + """ + from .utils import CustomFormatter, get_mfa_version + + if self.clean: + shutil.rmtree(self.output_directory, ignore_errors=True) + os.makedirs(self.workflow_directory, exist_ok=True) + if os.path.exists(self.log_file): + os.remove(self.log_file) + self.logger = logging.getLogger(self.workflow_identifier) + self.logger.setLevel(logging.DEBUG) + + file_handler = logging.FileHandler(self.log_file, encoding="utf8") + file_handler.setLevel(logging.DEBUG) + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + file_handler.setFormatter(formatter) + self.logger.addHandler(file_handler) + + handler = logging.StreamHandler(sys.stdout) + if self.verbose: + handler.setLevel(logging.DEBUG) + else: + handler.setLevel(logging.INFO) + handler.setFormatter(CustomFormatter()) + self.logger.addHandler(handler) + self.logger.debug(f"Set up logger for MFA version: {get_mfa_version()}") + if self.clean: + self.logger.debug("Cleaned previous run") + + def log_debug(self, message: str) -> None: + """ + Log a debug message. This function is a wrapper around the :meth:`logging.Logger.debug` + + Parameters + ---------- + message: str + Debug message to log + """ + self.logger.debug(message) + + def log_info(self, message: str) -> None: + """ + Log an info message. This function is a wrapper around the :meth:`logging.Logger.info` + + Parameters + ---------- + message: str + Info message to log + """ + self.logger.info(message) + + def log_warning(self, message: str) -> None: + """ + Log a warning message. 
This function is a wrapper around the :meth:`logging.Logger.warning` + + Parameters + ---------- + message: str + Warning message to log + """ + self.logger.warning(message) + + def log_error(self, message: str) -> None: + """ + Log an error message. This function is a wrapper around the :meth:`logging.Logger.error` + + Parameters + ---------- + message: str + Error message to log + """ + self.logger.error(message) + + +class ModelExporterMixin(metaclass=ABCMeta): + """ + Abstract mixin class for exporting MFA models + + Parameters + ---------- + overwrite: bool + Flag for whether to overwrite the specified path if a file exists + """ + + def __init__(self, overwrite: bool = False, **kwargs): + self.overwrite = overwrite + super().__init__(**kwargs) @property @abstractmethod - def model_path(self) -> str: - """Acoustic model file path""" + def meta(self) -> MetaDict: + """Training configuration parameters""" ... - @property @abstractmethod - def alignment_model_path(self) -> str: - """Acoustic model file path for speaker-independent alignment""" + def export_model(self, output_model_path: str) -> None: + """ + Abstract method to export an MFA model + + Parameters + ---------- + output_model_path: str + Path to export model + """ ... -class Transcriber(AcousticModelWorker): - """Abstract class for MFA transcribers""" +class FileExporterMixin(metaclass=ABCMeta): + """ + Abstract mixin class for exporting TextGrid and text files - def __init__( - self, - corpus: Corpus, - dictionary: MultispeakerDictionary, - acoustic_model: AcousticModel, - language_model: LanguageModel, - transcribe_config: TranscribeConfig, - ): - super().__init__(corpus, dictionary) - self.acoustic_model = acoustic_model - self.language_model = language_model - self.transcribe_config = transcribe_config + Parameters + ---------- + overwrite: bool + Flag for whether to overwrite files if they already exist - @abstractmethod - def transcribe(self) -> None: - """Perform transcription""" - ... 
+ """ + + def __init__(self, overwrite: bool = False, cleanup_textgrids: bool = True, **kwargs): + self.overwrite = overwrite + self.cleanup_textgrids = cleanup_textgrids + super().__init__(**kwargs) @property + def backup_output_directory(self) -> Optional[str]: + """Path to store files if overwriting is not allowed""" + if self.overwrite: + return None + return os.path.join(self.working_directory, "backup") + @abstractmethod - def model_path(self) -> str: - """Acoustic model file path""" + def export_files(self, output_directory: str) -> None: + """ + Export files to an output directory + + Parameters + ---------- + output_directory: str + Directory to export to + """ ... -class IvectorExtractor(AcousticModelWorker): - """Abstract class for MFA ivector extractors""" +class TrainerMixin(ModelExporterMixin): + """ + Abstract mixin class for MFA trainers - @abstractmethod - def extract_ivectors(self) -> None: - """Extract ivectors""" - ... + Parameters + ---------- + num_iterations: int + Number of training iterations + + Attributes + ---------- + iteration: int + Current iteration + """ + + def __init__(self, num_iterations: int = 40, **kwargs): + super().__init__(**kwargs) + self.iteration: int = 0 + self.num_iterations = num_iterations - @property @abstractmethod - def model_path(self) -> str: - """Acoustic model file path""" + def initialize_training(self) -> None: + """Initialize training""" ... - @property @abstractmethod - def ivector_options(self) -> MetaDict: - """Ivector parameters""" + def train(self) -> None: + """Perform training""" ... - @property @abstractmethod - def dubm_path(self) -> str: - """DUBM model file path""" + def train_iteration(self) -> None: + """Run one training iteration""" ... - @property @abstractmethod - def ie_path(self) -> str: - """Ivector extractor model file path""" + def finalize_training(self) -> None: + """Finalize training""" ... 
-class Dictionary(ABC): - """Abstract class for pronunciation dictionaries""" +class AdapterMixin(ModelExporterMixin): + """ + Abstract class for MFA model adaptation + """ - def __init__(self, dictionary_model: DictionaryModel, config: DictionaryConfig): - self.name = dictionary_model.name - self.dictionary_model = dictionary_model - self.config = config + @abstractmethod + def adapt(self) -> None: + """Perform adaptation""" + ... class MfaModel(ABC): """Abstract class for MFA models""" - @property - @abstractmethod - def extensions(self) -> Collection: - """File extensions for the model""" - ... + extensions: list[str] + model_type = "base_model" - @extensions.setter - @abstractmethod - def extensions(self, val: Collection) -> None: - ... + @classmethod + def pretrained_directory(cls) -> str: + from .config import get_temporary_directory + + return os.path.join(get_temporary_directory(), "pretrained_models", cls.model_type) + + @classmethod + def get_available_models(cls) -> list[str]: + """ + Get a list of available models for a given model type + + Returns + ------- + list[str] + List of model names + """ + if not os.path.exists(cls.pretrained_directory()): + return [] + available = [] + for f in os.listdir(cls.pretrained_directory()): + if cls.valid_extension(f): + available.append(os.path.splitext(f)[0]) + return available + + @classmethod + def get_pretrained_path(cls, name: str, enforce_existence: bool = True) -> str: + """ + Generate a path to a pretrained model based on its name and model type + + Parameters + ---------- + name: str + Name of model + enforce_existence: bool + Flag to return None if the path doesn't exist, defaults to True + + Returns + ------- + str + Path to model + """ + return cls.generate_path(cls.pretrained_directory(), name, enforce_existence) @classmethod @abstractmethod @@ -296,7 +742,7 @@ def generate_path(cls, root: str, name: str, enforce_existence: bool = True) -> ... 
@abstractmethod - def pretty_print(self): + def pretty_print(self) -> None: """Print the model's meta data""" ... @@ -307,6 +753,5 @@ def meta(self) -> MetaDict: ... @abstractmethod - def add_meta_file(self, trainer: Trainer) -> None: + def add_meta_file(self, trainer: TrainerMixin) -> None: """Add meta data to the model""" - ... diff --git a/montreal_forced_aligner/acoustic_modeling/__init__.py b/montreal_forced_aligner/acoustic_modeling/__init__.py new file mode 100644 index 00000000..c7b83ea8 --- /dev/null +++ b/montreal_forced_aligner/acoustic_modeling/__init__.py @@ -0,0 +1,27 @@ +""" +Training acoustic models +======================== + + +""" +from montreal_forced_aligner.acoustic_modeling.base import AcousticModelTrainingMixin # noqa +from montreal_forced_aligner.acoustic_modeling.lda import LdaTrainer # noqa +from montreal_forced_aligner.acoustic_modeling.monophone import MonophoneTrainer # noqa +from montreal_forced_aligner.acoustic_modeling.sat import SatTrainer # noqa +from montreal_forced_aligner.acoustic_modeling.trainer import TrainableAligner # noqa +from montreal_forced_aligner.acoustic_modeling.triphone import TriphoneTrainer # noqa + +__all__ = [ + "AcousticModelTrainingMixin", + "LdaTrainer", + "MonophoneTrainer", + "SatTrainer", + "TriphoneTrainer", + "TrainableAligner", + "base", + "lda", + "monophone", + "sat", + "triphone", + "trainer", +] diff --git a/montreal_forced_aligner/acoustic_modeling/base.py b/montreal_forced_aligner/acoustic_modeling/base.py new file mode 100644 index 00000000..ca7bf9ad --- /dev/null +++ b/montreal_forced_aligner/acoustic_modeling/base.py @@ -0,0 +1,931 @@ +"""Class definition for BaseTrainer""" +from __future__ import annotations + +import logging +import os +import re +import shutil +import statistics +import subprocess +import time +from abc import abstractmethod +from typing import TYPE_CHECKING, NamedTuple, Optional + +from tqdm import tqdm + +from montreal_forced_aligner.abc import MfaWorker, 
ModelExporterMixin, TrainerMixin +from montreal_forced_aligner.alignment.base import AlignMixin +from montreal_forced_aligner.corpus.acoustic_corpus import AcousticCorpusPronunciationMixin +from montreal_forced_aligner.corpus.features import FeatureConfigMixin +from montreal_forced_aligner.exceptions import KaldiProcessingError +from montreal_forced_aligner.helper import align_phones +from montreal_forced_aligner.models import AcousticModel +from montreal_forced_aligner.textgrid import process_ctm_line +from montreal_forced_aligner.utils import ( + log_kaldi_errors, + parse_logs, + run_mp, + run_non_mp, + thirdparty_binary, +) + +if TYPE_CHECKING: + from montreal_forced_aligner.abc import MetaDict + from montreal_forced_aligner.corpus.multiprocessing import Job + from montreal_forced_aligner.textgrid import CtmInterval + + +__all__ = ["AcousticModelTrainingMixin"] + + +class AlignmentImprovementArguments(NamedTuple): + """Arguments for :func:`~montreal_forced_aligner.acoustic_modeling.base.compute_alignment_improvement_func`""" + + log_path: str + dictionaries: list[str] + model_path: str + text_int_paths: dict[str, str] + word_boundary_paths: dict[str, str] + ali_paths: dict[str, str] + frame_shift: int + reversed_phone_mappings: dict[str, dict[int, str]] + positions: dict[str, list[str]] + phone_ctm_paths: dict[str, str] + + +class AccStatsArguments(NamedTuple): + """ + Arguments for :func:`~montreal_forced_aligner.acoustic_modeling.base.acc_stats_func` + """ + + log_path: str + dictionaries: list[str] + feature_strings: dict[str, str] + ali_paths: dict[str, str] + acc_paths: dict[str, str] + model_path: str + + +def acc_stats_func( + log_path: str, + dictionaries: list[str], + feature_strings: dict[str, str], + ali_paths: dict[str, str], + acc_paths: dict[str, str], + model_path: str, +) -> None: + """ + Multiprocessing function for accumulating stats in GMM training. 
+ + See Also + -------- + :meth:`.AcousticModelTrainingMixin.acc_stats` + Main function that calls this function in parallel + :meth:`.AcousticModelTrainingMixin.acc_stats_arguments` + Job method for generating arguments for this function + :kaldi_src:`gmm-acc-stats-ali` + Relevant Kaldi binary + + Parameters + ---------- + log_path: str + Path to save log output + dictionaries: list[str] + List of dictionary names + feature_strings: dict[str, str] + Dictionary of feature strings per dictionary name + ali_paths: dict[str, str] + Dictionary of alignment archives per dictionary name + acc_paths: dict[str, str] + Dictionary of accumulated stats files per dictionary name + model_path: str + Path to the acoustic model file + """ + model_path = model_path + with open(log_path, "w", encoding="utf8") as log_file: + for dict_name in dictionaries: + acc_proc = subprocess.Popen( + [ + thirdparty_binary("gmm-acc-stats-ali"), + model_path, + feature_strings[dict_name], + f"ark,s,cs:{ali_paths[dict_name]}", + acc_paths[dict_name], + ], + stderr=log_file, + env=os.environ, + ) + acc_proc.communicate() + + +def compute_alignment_improvement_func( + log_path: str, + dictionaries: list[str], + model_path: str, + text_int_paths: dict[str, str], + word_boundary_paths: dict[str, str], + ali_paths: dict[str, str], + frame_shift: int, + reversed_phone_mappings: dict[str, dict[int, str]], + positions: dict[str, list[str]], + phone_ctm_paths: dict[str, str], +) -> None: + """ + Multiprocessing function for computing alignment improvement over training + + See Also + -------- + :meth:`.AcousticModelTrainingMixin.compute_alignment_improvement` + Main function that calls this function in parallel + :meth:`.AcousticModelTrainingMixin.alignment_improvement_arguments` + Job method for generating arguments for the helper function + :kaldi_src:`linear-to-nbest` + Relevant Kaldi binary + :kaldi_src:`lattice-determinize-pruned` + Relevant Kaldi binary + :kaldi_src:`lattice-align-words` + Relevant 
Kaldi binary + :kaldi_src:`lattice-to-phone-lattice` + Relevant Kaldi binary + :kaldi_src:`nbest-to-ctm` + Relevant Kaldi binary + + Parameters + ---------- + log_path: str + Path to save log output + dictionaries: list[str] + List of dictionary names + model_path: str + Path to the acoustic model file + text_int_paths: dict[str, str] + Dictionary of text int files per dictionary name + word_boundary_paths: dict[str, str] + Dictionary of word boundary files per dictionary name + ali_paths: dict[str, str] + Dictionary of alignment archives per dictionary name + frame_shift: int + Frame shift of feature generation, in ms + reversed_phone_mappings: dict[str, dict[int, str]] + Mapping of phone IDs to phone labels per dictionary name + positions: dict[str, list[str]] + Positions per dictionary name + phone_ctm_paths: dict[str, str] + Dictionary of phone ctm files per dictionary name + """ + try: + + frame_shift = frame_shift / 1000 + with open(log_path, "w", encoding="utf8") as log_file: + for dict_name in dictionaries: + text_int_path = text_int_paths[dict_name] + ali_path = ali_paths[dict_name] + phone_ctm_path = phone_ctm_paths[dict_name] + word_boundary_path = word_boundary_paths[dict_name] + if os.path.exists(phone_ctm_path): + continue + + lin_proc = subprocess.Popen( + [ + thirdparty_binary("linear-to-nbest"), + f"ark:{ali_path}", + f"ark:{text_int_path}", + "", + "", + "ark:-", + ], + stdout=subprocess.PIPE, + stderr=log_file, + env=os.environ, + ) + det_proc = subprocess.Popen( + [thirdparty_binary("lattice-determinize-pruned"), "ark:-", "ark:-"], + stdin=lin_proc.stdout, + stderr=log_file, + stdout=subprocess.PIPE, + env=os.environ, + ) + align_proc = subprocess.Popen( + [ + thirdparty_binary("lattice-align-words"), + word_boundary_path, + model_path, + "ark:-", + "ark:-", + ], + stdin=det_proc.stdout, + stderr=log_file, + stdout=subprocess.PIPE, + env=os.environ, + ) + phone_proc = subprocess.Popen( + [thirdparty_binary("lattice-to-phone-lattice"), 
model_path, "ark:-", "ark:-"], + stdin=align_proc.stdout, + stdout=subprocess.PIPE, + stderr=log_file, + env=os.environ, + ) + nbest_proc = subprocess.Popen( + [ + thirdparty_binary("nbest-to-ctm"), + f"--frame-shift={frame_shift}", + "ark:-", + phone_ctm_path, + ], + stdin=phone_proc.stdout, + stderr=log_file, + env=os.environ, + ) + nbest_proc.communicate() + mapping = reversed_phone_mappings[dict_name] + actual_lines = [] + with open(phone_ctm_path, "r", encoding="utf8") as f: + for line in f: + line = line.strip() + if line == "": + continue + line = line.split(" ") + utt = line[0] + begin = float(line[2]) + duration = float(line[3]) + end = begin + duration + label = line[4] + try: + label = mapping[int(label)] + except KeyError: + pass + for p in positions[dict_name]: + if label.endswith(p): + label = label[: -1 * len(p)] + actual_lines.append([utt, begin, end, label]) + with open(phone_ctm_path, "w", encoding="utf8") as f: + for line in actual_lines: + f.write(f"{' '.join(map(str, line))}\n") + except Exception as e: + raise (Exception(str(e))) + + +def compare_alignments( + alignments_one: dict[str, list[CtmInterval]], + alignments_two: dict[str, list[CtmInterval]], + silence_phones: set[str], +) -> tuple[Optional[int], Optional[float]]: + """ + Compares two sets of alignments for difference + + See Also + -------- + :meth:`.AcousticModelTrainingMixin.compute_alignment_improvement` + Main function that calls this function + + Parameters + ---------- + alignments_one: dict[str, list[tuple[float, float, str]]] + First set of alignments + alignments_two: dict[str, list[tuple[float, float, str]]] + Second set of alignments + frame_shift: int + Frame shift in feature generation, in ms + + Returns + ------- + Optional[int] + Difference in number of aligned files + Optional[float] + Mean boundary difference between the two alignments + """ + utterances_aligned_diff = len(alignments_two) - len(alignments_one) + utts_one = set(alignments_one.keys()) + utts_two = 
set(alignments_two.keys()) + common_utts = utts_one.intersection(utts_two) + differences = [] + for u in common_utts: + one_alignment = alignments_one[u] + two_alignment = alignments_two[u] + avg_overlap_diff, num_insertions, num_deletions = align_phones( + one_alignment, two_alignment, silence_phones + ) + if avg_overlap_diff is None: + return None, None + differences.append(avg_overlap_diff) + if differences: + mean_difference = statistics.mean(differences) + else: + mean_difference = None + return utterances_aligned_diff, mean_difference + + +class AcousticModelTrainingMixin( + AlignMixin, TrainerMixin, FeatureConfigMixin, MfaWorker, ModelExporterMixin +): + """ + Base trainer class for training acoustic models and ivector extractors + + Parameters + ---------- + identifier : str + Identifier for the trainer + worker: :class:`~montreal_forced_aligner.corpus.acoustic_corpus.AcousticCorpusPronunciationMixin` + Top-level worker + num_iterations : int + Number of iterations, defaults to 40 + subset : int + Number of utterances to use, defaults to 0 which will use the whole corpus + max_gaussians : int + Total number of gaussians, defaults to 1000 + boost_silence : float + Factor by which to boost silence during alignment, defaults to 1.25 + power : float + Exponent for number of gaussians according to occurrence counts, defaults to 0.25 + initial_gaussians : int + Initial number of gaussians, defaults to 0 + + See Also + -------- + :class:`~montreal_forced_aligner.alignment.mixins.AlignMixin` + For alignment parameters + :class:`~montreal_forced_aligner.abc.TrainerMixin` + For training parameters + :class:`~montreal_forced_aligner.corpus.features.FeatureConfigMixin` + For feature generation parameters + :class:`~montreal_forced_aligner.abc.MfaWorker` + For MFA processing parameters + :class:`~montreal_forced_aligner.abc.ModelExporterMixin` + For model export parameters + + Attributes + ---------- + realignment_iterations : list + List of iterations to perform 
alignment + """ + + architecture = "gmm-hmm" + + def __init__( + self, + identifier: str, + worker: AcousticCorpusPronunciationMixin, + num_iterations: int = 40, + subset: int = 0, + max_gaussians: int = 1000, + boost_silence: float = 1.25, + power: float = 0.25, + initial_gaussians: int = 0, + **kwargs, + ): + super().__init__(**kwargs) + self.identifier = identifier + self.worker = worker + self.num_iterations = num_iterations + self.subset = subset + self.max_gaussians = max_gaussians + self.power = power + self.initial_gaussians = initial_gaussians + self.current_gaussians = initial_gaussians + self.boost_silence = boost_silence + self.training_complete = False + self.realignment_iterations = [] # Gets set later + + def acc_stats_arguments(self) -> list[AccStatsArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.acoustic_modeling.base.acc_stats_func` + + Returns + ------- + list[:class:`~montreal_forced_aligner.acoustic_modeling.base.AccStatsArguments`] + Arguments for processing + """ + feat_strings = self.worker.construct_feature_proc_strings() + return [ + AccStatsArguments( + os.path.join(self.working_directory, "log", f"acc.{self.iteration}.{j.name}.log"), + j.current_dictionary_names, + feat_strings[j.name], + j.construct_path_dictionary(self.working_directory, "ali", "ark"), + j.construct_path_dictionary(self.working_directory, str(self.iteration), "acc"), + self.model_path, + ) + for j in self.jobs + ] + + def alignment_improvement_arguments(self) -> list[AlignmentImprovementArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.acoustic_modeling.base.compute_alignment_improvement_func` + + Returns + ------- + list[:class:`~montreal_forced_aligner.acoustic_modeling.base.AlignmentImprovementArguments`] + Arguments for processing + """ + return [ + AlignmentImprovementArguments( + os.path.join(self.working_log_directory, f"alignment_analysis.{j.name}.log"), + j.current_dictionary_names, + self.model_path, 
+ j.construct_path_dictionary(self.data_directory, "text", "int.scp"), + j.word_boundary_int_files(), + j.construct_path_dictionary(self.working_directory, "ali", "ark"), + self.frame_shift, + j.reversed_phone_mappings(), + j.positions(), + j.construct_path_dictionary( + self.working_directory, f"phone.{self.iteration}", "ctm" + ), + ) + for j in self.jobs + ] + + @property + def previous_aligner(self): + """Previous aligner seeding training""" + return self.worker + + def log_debug(self, message: str) -> None: + """ + Log a debug message. This function is a wrapper around the worker's :meth:`logging.Logger.debug` + + Parameters + ---------- + message: str + Debug message to log + """ + self.worker.log_debug(message) + + def log_error(self, message: str) -> None: + """ + Log an info message. This function is a wrapper around the worker's :meth:`logging.Logger.info` + + Parameters + ---------- + message: str + Info message to log + """ + self.worker.log_error(message) + + def log_warning(self, message: str) -> None: + """ + Log a warning message. This function is a wrapper around the worker's :meth:`logging.Logger.warning` + + Parameters + ---------- + message: str + Warning message to log + """ + self.worker.log_warning(message) + + def log_info(self, message: str) -> None: + """ + Log an error message. 
This function is a wrapper around the worker's :meth:`logging.Logger.error` + + Parameters + ---------- + message: str + Error message to log + """ + self.worker.log_info(message) + + @property + def logger(self) -> logging.Logger: + """Top-level worker's logger""" + return self.worker.logger + + @property + def jobs(self) -> list[Job]: + """Top-level worker's job objects""" + return self.worker.jobs + + @property + def disambiguation_symbols_int_path(self) -> str: + """Path to the disambiguation int file""" + return self.worker.disambiguation_symbols_int_path + + def construct_feature_proc_strings( + self, speaker_independent: bool = False + ) -> list[dict[str, str]]: + """Top-level worker's feature strings""" + return self.worker.construct_feature_proc_strings(speaker_independent) + + def construct_base_feature_string(self, all_feats: bool = False) -> str: + """Top-level worker's base feature string""" + return self.worker.construct_base_feature_string(all_feats) + + @property + def data_directory(self) -> str: + """Get the current data directory based on subset""" + return self.worker.data_directory + + @property + def corpus_output_directory(self) -> str: + """Directory of the corpus""" + return self.worker.corpus_output_directory + + def initialize_training(self) -> None: + """Initialize training""" + self.compute_calculated_properties() + self.current_gaussians = 0 + begin = time.time() + dirty_path = os.path.join(self.working_directory, "dirty") + done_path = os.path.join(self.working_directory, "done") + if os.path.exists(dirty_path): # if there was an error, let's redo from scratch + shutil.rmtree(self.working_directory) + self.logger.info(f"Initializing training for {self.identifier}...") + if os.path.exists(done_path): + self.training_complete = True + return + os.makedirs(self.working_directory, exist_ok=True) + os.makedirs(self.working_log_directory, exist_ok=True) + if self.subset is not None and self.subset > self.worker.num_utterances: + 
self.logger.warning( + "Subset specified is larger than the dataset, " + "using full corpus for this training block." + ) + + try: + self._trainer_initialization() + parse_logs(self.working_log_directory) + except Exception as e: + with open(dirty_path, "w"): + pass + if isinstance(e, KaldiProcessingError): + log_kaldi_errors(e.error_logs, self.logger) + e.update_log_file(self.logger) + raise + self.iteration = 1 + self.worker.current_trainer = self + self.logger.info("Initialization complete!") + self.logger.debug( + f"Initialization for {self.identifier} took {time.time() - begin} seconds" + ) + + @abstractmethod + def _trainer_initialization(self) -> None: + """Descendant classes will override this for their own training initialization""" + ... + + def acoustic_model_training_params(self) -> MetaDict: + """Configuration parameters""" + return { + "subset": self.subset, + "num_iterations": self.num_iterations, + "max_gaussians": self.max_gaussians, + "power": self.power, + "initial_gaussians": self.initial_gaussians, + } + + @property + def working_directory(self) -> str: + """Training directory""" + return os.path.join(self.worker.output_directory, self.identifier) + + @property + def working_log_directory(self) -> str: + """Training log directory""" + return os.path.join(self.working_directory, "log") + + @property + def model_path(self) -> str: + """Current acoustic model path""" + if self.training_complete: + return self.next_model_path + return os.path.join(self.working_directory, f"{self.iteration}.mdl") + + @property + def alignment_model_path(self) -> str: + """Alignment model path""" + return self.model_path + + @property + def next_model_path(self): + """Next iteration's acoustic model path""" + if self.training_complete: + return os.path.join(self.working_directory, "final.mdl") + return os.path.join(self.working_directory, f"{self.iteration + 1}.mdl") + + @property + def next_occs_path(self): + """Next iteration's occs file path""" + if 
self.training_complete: + return os.path.join(self.working_directory, "final.occs") + return os.path.join(self.working_directory, f"{self.iteration + 1}.occs") + + @abstractmethod + def compute_calculated_properties(self) -> None: + """Compute any calculated properties such as alignment iterations""" + ... + + def increment_gaussians(self): + """Increment the current number of gaussians""" + self.current_gaussians += self.gaussian_increment + + def acc_stats(self): + """ + Multiprocessing function that accumulates stats for GMM training. + + See Also + -------- + :func:`~montreal_forced_aligner.acoustic_modeling.base.acc_stats_func` + Multiprocessing helper function for each job + :meth:`.AcousticModelTrainingMixin.acc_stats_arguments` + Job method for generating arguments for the helper function + :kaldi_src:`gmm-sum-accs` + Relevant Kaldi binary + :kaldi_src:`gmm-est` + Relevant Kaldi binary + :kaldi_steps:`train_mono` + Reference Kaldi script + :kaldi_steps:`train_deltas` + Reference Kaldi script + """ + arguments = self.acc_stats_arguments() + + if self.use_mp: + run_mp(acc_stats_func, arguments, self.working_log_directory) + else: + run_non_mp(acc_stats_func, arguments, self.working_log_directory) + + log_path = os.path.join(self.working_log_directory, f"update.{self.iteration}.log") + with open(log_path, "w") as log_file: + acc_files = [] + for a in arguments: + acc_files.extend(a.acc_paths.values()) + sum_proc = subprocess.Popen( + [thirdparty_binary("gmm-sum-accs"), "-"] + acc_files, + stdout=subprocess.PIPE, + stderr=log_file, + env=os.environ, + ) + est_proc = subprocess.Popen( + [ + thirdparty_binary("gmm-est"), + f"--write-occs={self.next_occs_path}", + f"--mix-up={self.current_gaussians}", + f"--power={self.power}", + self.model_path, + "-", + self.next_model_path, + ], + stdin=sum_proc.stdout, + stderr=log_file, + env=os.environ, + ) + est_proc.communicate() + avg_like_pattern = re.compile( + r"Overall avg like per frame \(Gaussian only\) = 
(?P[-.,\d]+) over (?P[.\d+e]) frames" + ) + average_logdet_pattern = re.compile( + r"Overall average logdet is (?P[-.,\d]+) over (?P[.\d+e]) frames" + ) + avg_like_sum = 0 + avg_like_frames = 0 + average_logdet_sum = 0 + average_logdet_frames = 0 + for a in arguments: + with open(a.log_path, "r", encoding="utf8") as f: + for line in f: + m = re.search(avg_like_pattern, line) + if m: + like = float(m.group("like")) + frames = float(m.group("frames")) + avg_like_sum += like * frames + avg_like_frames += frames + m = re.search(average_logdet_pattern, line) + if m: + logdet = float(m.group("logdet")) + frames = float(m.group("frames")) + average_logdet_sum += logdet * frames + average_logdet_frames += frames + if avg_like_frames: + log_like = avg_like_sum / avg_like_frames + if average_logdet_frames: + log_like += average_logdet_sum / average_logdet_frames + self.logger.debug(f"Likelihood for iteration {self.iteration}: {log_like}") + + if not self.debug: + for f in acc_files: + os.remove(f) + + def parse_iteration_alignments( + self, iteration: Optional[int] = None + ) -> dict[str, list[CtmInterval]]: + """ + Function to parse phone CTMs in a given iteration + + Parameters + ---------- + iteration: int, optional + Iteration to compute over + + Returns + ------- + dict[str, list[CtmInterval]] + Per utterance CtmIntervals + """ + data = {} + for j in self.alignment_improvement_arguments(): + for phone_ctm_path in j.phone_ctm_paths.values(): + if iteration is not None: + phone_ctm_path = phone_ctm_path.replace( + f"phone.{self.iteration}", f"phone.{iteration}" + ) + with open(phone_ctm_path, "r", encoding="utf8") as f: + for line in f: + line = line.strip() + if line == "": + continue + interval = process_ctm_line(line) + if interval.utterance not in data: + data[interval.utterance] = [] + data[interval.utterance].append(interval) + return data + + def compute_alignment_improvement(self) -> None: + """ + Computes aligner improvements in terms of number of aligned files 
and phone boundaries + for debugging purposes + """ + jobs = self.alignment_improvement_arguments() + if self.use_mp: + run_mp(compute_alignment_improvement_func, jobs, self.working_log_directory) + else: + run_non_mp(compute_alignment_improvement_func, jobs, self.working_log_directory) + + alignment_diff_path = os.path.join(self.working_directory, "train_change.csv") + if self.iteration == 0 or self.iteration not in self.realignment_iterations: + return + ind = self.realignment_iterations.index(self.iteration) + if ind != 0: + previous_iteration = self.realignment_iterations[ind - 1] + else: + previous_iteration = 0 + try: + previous_alignments = self.parse_iteration_alignments(previous_iteration) + except FileNotFoundError: + return + current_alignments = self.parse_iteration_alignments() + utterance_aligned_diff, mean_difference = compare_alignments( + previous_alignments, current_alignments, self.silence_phones + ) + if utterance_aligned_diff: + self.log_warning( + "Cannot compare alignments, install the biopython package to use this functionality." + ) + return + if not os.path.exists(alignment_diff_path): + with open(alignment_diff_path, "w", encoding="utf8") as f: + f.write( + "iteration,number_aligned,number_previously_aligned," + "difference_in_utts_aligned,mean_boundary_change\n" + ) + if self.iteration in self.realignment_iterations: + with open(alignment_diff_path, "a", encoding="utf8") as f: + f.write( + f"{self.iteration},{len(current_alignments)},{len(previous_alignments)}," + f"{utterance_aligned_diff},{mean_difference}\n" + ) + + def train_iteration(self): + """Perform an iteration of training""" + if os.path.exists(self.next_model_path): + self.iteration += 1 + return + if self.iteration in self.realignment_iterations: + self.align_utterances() + self.logger.debug( + f"Analyzing information for alignment in iteration {self.iteration}..." 
+ ) + self.compile_information() + if self.debug: + self.compute_alignment_improvement() + self.acc_stats() + + parse_logs(self.working_log_directory) + if self.iteration < self.final_gaussian_iteration: + self.increment_gaussians() + self.iteration += 1 + + def train(self): + """ + Train the model + + Raises + ------ + :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` + If there were any errors in running Kaldi binaries + """ + done_path = os.path.join(self.working_directory, "done") + dirty_path = os.path.join(self.working_directory, "dirty") + if os.path.exists(done_path): + self.logger.info(f"{self.identifier} training already done, skipping initialization.") + return + try: + self.initialize_training() + begin = time.time() + with tqdm(initial=1, total=self.num_iterations + 1) as pbar: + while self.iteration < self.num_iterations + 1: + self.train_iteration() + pbar.update(1) + self.finalize_training() + except Exception as e: + with open(dirty_path, "w"): + pass + if isinstance(e, KaldiProcessingError): + log_kaldi_errors(e.error_logs, self.logger) + e.update_log_file(self.logger) + raise + with open(done_path, "w"): + pass + self.logger.info("Training complete!") + self.logger.debug(f"Training took {time.time() - begin} seconds") + + @property + def exported_model_path(self) -> str: + """Model path to export to once training is complete""" + return os.path.join(self.working_log_directory, "acoustic_model.zip") + + def finalize_training(self) -> None: + """ + Finalize the training, renaming all final iteration model files as "final", and exporting + the model to be used in the next round alignment + + """ + shutil.copy( + os.path.join(self.working_directory, f"{self.num_iterations+1}.mdl"), + os.path.join(self.working_directory, "final.mdl"), + ) + shutil.copy( + os.path.join(self.working_directory, f"{self.num_iterations+1}.occs"), + os.path.join(self.working_directory, "final.occs"), + ) + self.export_model(self.exported_model_path) + if not 
self.debug: + for i in range(1, self.num_iterations + 1): + model_path = os.path.join(self.working_directory, f"{i}.mdl") + try: + os.remove(model_path) + except FileNotFoundError: + pass + try: + os.remove(os.path.join(self.working_directory, f"{i}.occs")) + except FileNotFoundError: + pass + self.training_complete = True + self.worker.current_trainer = None + + @property + def final_gaussian_iteration(self) -> int: + """Final iteration to increase gaussians""" + return self.num_iterations - 10 + + @property + def gaussian_increment(self) -> int: + """Amount by which gaussians should be increases each iteration""" + return int((self.max_gaussians - self.initial_gaussians) / self.final_gaussian_iteration) + + @property + def train_type(self) -> str: + """Training type, not implemented for BaseTrainer""" + raise NotImplementedError + + @property + def phone_type(self) -> str: + """Phone type, not implemented for BaseTrainer""" + raise NotImplementedError + + @property + def meta(self) -> MetaDict: + """Generate metadata for the acoustic model that was trained""" + from datetime import datetime + + from ..utils import get_mfa_version + + data = { + "phones": sorted(self.non_silence_phones), + "version": get_mfa_version(), + "architecture": self.architecture, + "train_date": str(datetime.now()), + "features": self.feature_options, + "multilingual_ipa": self.multilingual_ipa, + } + if self.multilingual_ipa: + data["strip_diacritics"] = self.strip_diacritics + data["digraphs"] = self.digraphs + return data + + def export_model(self, output_model_path: str) -> None: + """ + Export an acoustic model to the specified path + + Parameters + ---------- + output_model_path : str + Path to save acoustic model + """ + directory, filename = os.path.split(output_model_path) + basename, _ = os.path.splitext(filename) + acoustic_model = AcousticModel.empty(basename, root_directory=self.working_log_directory) + acoustic_model.add_meta_file(self) + 
acoustic_model.add_model(self.working_directory) + if directory: + os.makedirs(directory, exist_ok=True) + basename, _ = os.path.splitext(output_model_path) + acoustic_model.dump(output_model_path) diff --git a/montreal_forced_aligner/acoustic_modeling/lda.py b/montreal_forced_aligner/acoustic_modeling/lda.py new file mode 100644 index 00000000..51400c09 --- /dev/null +++ b/montreal_forced_aligner/acoustic_modeling/lda.py @@ -0,0 +1,479 @@ +"""Class definitions for LDA trainer""" +from __future__ import annotations + +import os +import shutil +import subprocess +from typing import TYPE_CHECKING, NamedTuple + +from montreal_forced_aligner.acoustic_modeling.triphone import TriphoneTrainer +from montreal_forced_aligner.utils import parse_logs, run_mp, run_non_mp, thirdparty_binary + +if TYPE_CHECKING: + from montreal_forced_aligner.abc import MetaDict + + +__all__ = ["LdaTrainer"] + + +class LdaAccStatsArguments(NamedTuple): + """Arguments for :func:`~montreal_forced_aligner.acoustic_modeling.lda.lda_acc_stats_func`""" + + log_path: str + dictionaries: list[str] + feature_strings: dict[str, str] + ali_paths: dict[str, str] + model_path: str + lda_options: MetaDict + acc_paths: dict[str, str] + + +class CalcLdaMlltArguments(NamedTuple): + """Arguments for :func:`~montreal_forced_aligner.acoustic_modeling.lda.calc_lda_mllt_func`""" + + log_path: str + """Log file to save stderr""" + dictionaries: list[str] + feature_strings: dict[str, str] + ali_paths: dict[str, str] + model_path: str + lda_options: MetaDict + macc_paths: dict[str, str] + + +def lda_acc_stats_func( + log_path: str, + dictionaries: list[str], + feature_strings: dict[str, str], + ali_paths: dict[str, str], + model_path: str, + lda_options: MetaDict, + acc_paths: dict[str, str], +) -> None: + """ + Multiprocessing function to accumulate LDA stats + + See Also + -------- + :meth:`.LdaTrainer.lda_acc_stats` + Main function that calls this function in parallel + :meth:`.LdaTrainer.lda_acc_stats_arguments` + 
Job method for generating arguments for this function + :kaldi_src:`ali-to-post` + Relevant Kaldi binary + :kaldi_src:`weight-silence-post` + Relevant Kaldi binary + :kaldi_src:`acc-lda` + Relevant Kaldi binary + + Parameters + ---------- + log_path: str + Path to save log output + dictionaries: list[str] + List of dictionary names + feature_strings: dict[str, str] + Dictionary of feature strings per dictionary name + ali_paths: dict[str, str] + Dictionary of alignment archives per dictionary name + model_path: str + Path to the acoustic model file + lda_options: dict[str, Any] + Options for LDA + acc_paths: dict[str, str] + Dictionary of accumulated stats files per dictionary name + """ + with open(log_path, "w", encoding="utf8") as log_file: + for dict_name in dictionaries: + ali_path = ali_paths[dict_name] + feature_string = feature_strings[dict_name] + acc_path = acc_paths[dict_name] + ali_to_post_proc = subprocess.Popen( + [thirdparty_binary("ali-to-post"), f"ark:{ali_path}", "ark:-"], + stderr=log_file, + stdout=subprocess.PIPE, + env=os.environ, + ) + weight_silence_post_proc = subprocess.Popen( + [ + thirdparty_binary("weight-silence-post"), + f"{lda_options['boost_silence']}", + lda_options["silence_csl"], + model_path, + "ark:-", + "ark:-", + ], + stdin=ali_to_post_proc.stdout, + stderr=log_file, + stdout=subprocess.PIPE, + env=os.environ, + ) + acc_lda_post_proc = subprocess.Popen( + [ + thirdparty_binary("acc-lda"), + f"--rand-prune={lda_options['random_prune']}", + model_path, + feature_string, + "ark,s,cs:-", + acc_path, + ], + stdin=weight_silence_post_proc.stdout, + stderr=log_file, + env=os.environ, + ) + acc_lda_post_proc.communicate() + + +def calc_lda_mllt_func( + log_path: str, + dictionaries: list[str], + feature_strings: dict[str, str], + ali_paths: dict[str, str], + model_path: str, + lda_options: MetaDict, + macc_paths: dict[str, str], +) -> None: + """ + Multiprocessing function for estimating LDA with MLLT. 
+ + See Also + -------- + :meth:`.LdaTrainer.calc_lda_mllt` + Main function that calls this function in parallel + :meth:`.LdaTrainer.calc_lda_mllt_arguments` + Job method for generating arguments for this function + :kaldi_src:`ali-to-post` + Relevant Kaldi binary + :kaldi_src:`weight-silence-post` + Relevant Kaldi binary + :kaldi_src:`gmm-acc-mllt` + Relevant Kaldi binary + + Parameters + ---------- + log_path: str + Path to save log output + dictionaries: list[str] + List of dictionary names + feature_strings: dict[str, str] + Dictionary of feature strings per dictionary name + ali_paths: dict[str, str] + Dictionary of alignment archives per dictionary name + model_path: str + Path to the acoustic model file + lda_options: dict[str, Any] + Options for LDA + macc_paths: dict[str, str] + Dictionary of accumulated stats files per dictionary name + """ + # Estimating MLLT + with open(log_path, "w", encoding="utf8") as log_file: + for dict_name in dictionaries: + ali_path = ali_paths[dict_name] + feature_string = feature_strings[dict_name] + macc_path = macc_paths[dict_name] + post_proc = subprocess.Popen( + [thirdparty_binary("ali-to-post"), f"ark:{ali_path}", "ark:-"], + stdout=subprocess.PIPE, + stderr=log_file, + env=os.environ, + ) + + weight_proc = subprocess.Popen( + [ + thirdparty_binary("weight-silence-post"), + "0.0", + lda_options["silence_csl"], + model_path, + "ark:-", + "ark:-", + ], + stdin=post_proc.stdout, + stdout=subprocess.PIPE, + stderr=log_file, + env=os.environ, + ) + acc_proc = subprocess.Popen( + [ + thirdparty_binary("gmm-acc-mllt"), + f"--rand-prune={lda_options['random_prune']}", + model_path, + feature_string, + "ark,s,cs:-", + macc_path, + ], + stdin=weight_proc.stdout, + stderr=log_file, + env=os.environ, + ) + acc_proc.communicate() + + +class LdaTrainer(TriphoneTrainer): + """ + Triphone trainer + + Parameters + ---------- + subset : int + Number of utterances to use, defaults to 10000 + num_leaves : int + Number of states in the 
decision tree, defaults to 2500 + max_gaussians : int + Number of gaussians in the decision tree, defaults to 15000 + lda_dimension : int + Dimensionality of the LDA matrix + uses_splices : bool + Flag to use spliced and LDA calculation + splice_left_context : int or None + Number of frames to splice on the left for calculating LDA + splice_right_context : int or None + Number of frames to splice on the right for calculating LDA + random_prune : float + This is approximately the ratio by which we will speed up the + LDA and MLLT calculations via randomized pruning + + See Also + -------- + :class:`~montreal_forced_aligner.acoustic_modeling.triphone.TriphoneTrainer` + For acoustic model training parsing parameters + + Attributes + ---------- + mllt_iterations : list + List of iterations to perform MLLT estimation + """ + + def __init__( + self, + subset: int = 10000, + num_leaves: int = 2500, + max_gaussians=15000, + lda_dimension: int = 40, + uses_splices: bool = True, + splice_left_context: int = 3, + splice_right_context: int = 3, + random_prune=4.0, + **kwargs, + ): + super().__init__(**kwargs) + self.subset = subset + self.num_leaves = num_leaves + self.max_gaussians = max_gaussians + self.lda_dimension = lda_dimension + self.random_prune = random_prune + self.uses_splices = uses_splices + self.splice_left_context = splice_left_context + self.splice_right_context = splice_right_context + + def lda_acc_stats_arguments(self) -> list[LdaAccStatsArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.acoustic_modeling.lda.lda_acc_stats_func` + + Returns + ------- + list[:class:`~montreal_forced_aligner.acoustic_modeling.lda.LdaAccStatsArguments`] + Arguments for processing + """ + feat_strings = self.worker.construct_feature_proc_strings() + return [ + LdaAccStatsArguments( + os.path.join(self.working_log_directory, f"lda_acc_stats.{j.name}.log"), + j.current_dictionary_names, + feat_strings[j.name], + 
j.construct_path_dictionary(self.previous_aligner.working_directory, "ali", "ark"), + self.previous_aligner.alignment_model_path, + self.lda_options, + j.construct_path_dictionary(self.working_directory, "lda", "acc"), + ) + for j in self.jobs + ] + + def calc_lda_mllt_arguments(self) -> list[CalcLdaMlltArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.acoustic_modeling.lda.calc_lda_mllt_func` + + Returns + ------- + list[:class:`~montreal_forced_aligner.acoustic_modeling.lda.CalcLdaMlltArguments`] + Arguments for processing + """ + feat_strings = self.worker.construct_feature_proc_strings() + return [ + CalcLdaMlltArguments( + os.path.join( + self.working_log_directory, f"lda_mllt.{self.iteration}.{j.name}.log" + ), + j.current_dictionary_names, + feat_strings[j.name], + j.construct_path_dictionary(self.working_directory, "ali", "ark"), + self.model_path, + self.lda_options, + j.construct_path_dictionary(self.working_directory, "lda", "macc"), + ) + for j in self.jobs + ] + + @property + def train_type(self) -> str: + """Training identifier""" + return "lda" + + @property + def lda_options(self) -> MetaDict: + """Options for computing LDA""" + return { + "lda_dimension": self.lda_dimension, + "boost_silence": self.boost_silence, + "random_prune": self.random_prune, + "silence_csl": self.silence_csl, + } + + def compute_calculated_properties(self) -> None: + """Generate realignment iterations, MLLT estimation iterations, and initial gaussians based on configuration""" + super().compute_calculated_properties() + self.mllt_iterations = [] + max_mllt_iter = int(self.num_iterations / 2) - 1 + for i in range(1, max_mllt_iter): + if i < max_mllt_iter / 2 and i % 2 == 0: + self.mllt_iterations.append(i) + self.mllt_iterations.append(max_mllt_iter) + + def lda_acc_stats(self) -> None: + """ + Multiprocessing function that accumulates LDA statistics. 
+ + See Also + -------- + :func:`~montreal_forced_aligner.acoustic_modeling.lda.lda_acc_stats_func` + Multiprocessing helper function for each job + :meth:`.LdaTrainer.lda_acc_stats_arguments` + Job method for generating arguments for the helper function + :kaldi_src:`est-lda` + Relevant Kaldi binary + :kaldi_steps:`train_lda_mllt` + Reference Kaldi script + + """ + arguments = self.lda_acc_stats_arguments() + + if self.use_mp: + run_mp(lda_acc_stats_func, arguments, self.working_log_directory) + else: + run_non_mp(lda_acc_stats_func, arguments, self.working_log_directory) + + log_path = os.path.join(self.working_log_directory, "lda_est.log") + acc_list = [] + for x in arguments: + acc_list.extend(x.acc_paths.values()) + with open(log_path, "w", encoding="utf8") as log_file: + est_lda_proc = subprocess.Popen( + [ + thirdparty_binary("est-lda"), + f"--dim={self.lda_dimension}", + os.path.join(self.working_directory, "lda.mat"), + ] + + acc_list, + stderr=log_file, + env=os.environ, + ) + est_lda_proc.communicate() + shutil.copyfile( + os.path.join(self.working_directory, "lda.mat"), + os.path.join(self.worker.working_directory, "lda.mat"), + ) + + def _trainer_initialization(self) -> None: + """Initialize LDA training""" + self.uses_splices = True + self.worker.uses_splices = True + self.lda_acc_stats() + super()._trainer_initialization() + + def calc_lda_mllt(self) -> None: + """ + Multiprocessing function that calculates LDA+MLLT transformations. 
+ + See Also + -------- + :func:`~montreal_forced_aligner.acoustic_modeling.lda.calc_lda_mllt_func` + Multiprocessing helper function for each job + :meth:`.LdaTrainer.calc_lda_mllt_arguments` + Job method for generating arguments for the helper function + :kaldi_src:`est-mllt` + Relevant Kaldi binary + :kaldi_src:`gmm-transform-means` + Relevant Kaldi binary + :kaldi_src:`compose-transforms` + Relevant Kaldi binary + :kaldi_steps:`train_lda_mllt` + Reference Kaldi script + + """ + jobs = self.calc_lda_mllt_arguments() + + if self.use_mp: + run_mp(calc_lda_mllt_func, jobs, self.working_log_directory) + else: + run_non_mp(calc_lda_mllt_func, jobs, self.working_log_directory) + + log_path = os.path.join( + self.working_log_directory, f"transform_means.{self.iteration}.log" + ) + previous_mat_path = os.path.join(self.working_directory, "lda.mat") + new_mat_path = os.path.join(self.working_directory, "lda_new.mat") + composed_path = os.path.join(self.working_directory, "lda_composed.mat") + with open(log_path, "a", encoding="utf8") as log_file: + macc_list = [] + for x in jobs: + macc_list.extend(x.macc_paths.values()) + subprocess.call( + [thirdparty_binary("est-mllt"), new_mat_path] + macc_list, + stderr=log_file, + env=os.environ, + ) + subprocess.call( + [ + thirdparty_binary("gmm-transform-means"), + new_mat_path, + self.model_path, + self.model_path, + ], + stderr=log_file, + env=os.environ, + ) + + if os.path.exists(previous_mat_path): + subprocess.call( + [ + thirdparty_binary("compose-transforms"), + new_mat_path, + previous_mat_path, + composed_path, + ], + stderr=log_file, + env=os.environ, + ) + os.remove(previous_mat_path) + os.rename(composed_path, previous_mat_path) + else: + os.rename(new_mat_path, previous_mat_path) + + def train_iteration(self): + """ + Run a single LDA training iteration + """ + if os.path.exists(self.next_model_path): + return + if self.iteration in self.realignment_iterations: + self.align_utterances() + if self.debug: + 
self.compute_alignment_improvement() + if self.iteration in self.mllt_iterations: + self.calc_lda_mllt() + + self.acc_stats() + parse_logs(self.working_log_directory) + if self.iteration < self.final_gaussian_iteration: + self.increment_gaussians() + self.iteration += 1 diff --git a/montreal_forced_aligner/acoustic_modeling/monophone.py b/montreal_forced_aligner/acoustic_modeling/monophone.py new file mode 100644 index 00000000..53827ad8 --- /dev/null +++ b/montreal_forced_aligner/acoustic_modeling/monophone.py @@ -0,0 +1,284 @@ +"""Class definitions for Monophone trainer""" +from __future__ import annotations + +import os +import re +import subprocess +from typing import NamedTuple + +from montreal_forced_aligner.acoustic_modeling.base import AcousticModelTrainingMixin +from montreal_forced_aligner.utils import run_mp, run_non_mp, thirdparty_binary + + +class MonoAlignEqualArguments(NamedTuple): + """Arguments for :func:`~montreal_forced_aligner.acoustic_modeling.monophone.mono_align_equal_func`""" + + log_path: str + dictionaries: list[str] + feature_strings: dict[str, str] + fst_scp_paths: dict[str, str] + ali_ark_paths: dict[str, str] + acc_paths: dict[str, str] + model_path: str + + +def mono_align_equal_func( + log_path: str, + dictionaries: list[str], + feature_strings: dict[str, str], + fst_scp_paths: dict[str, str], + ali_ark_paths: dict[str, str], + acc_paths: dict[str, str], + model_path: str, +): + """ + Multiprocessing function for initializing monophone alignments + + See Also + -------- + :meth:`.MonophoneTrainer.mono_align_equal` + Main function that calls this function in parallel + :meth:`.MonophoneTrainer.mono_align_equal_arguments` + Job method for generating arguments for this function + :kaldi_src:`align-equal-compiled` + Relevant Kaldi binary + :kaldi_src:`gmm-acc-stats-ali` + Relevant Kaldi binary + + Parameters + ---------- + log_path: str + Path to save log output + dictionaries: list[str] + List of dictionary names + feature_strings: 
dict[str, str] + Dictionary of feature strings per dictionary name + fst_scp_paths: dict[str, str] + Dictionary of utterance FST scp files per dictionary name + ali_ark_paths: dict[str, str] + Dictionary of alignment archives per dictionary name + acc_paths: dict[str, str] + Dictionary of accumulated stats files per dictionary name + model_path: str + Path to the acoustic model file + """ + with open(log_path, "w", encoding="utf8") as log_file: + for dict_name in dictionaries: + fst_path = fst_scp_paths[dict_name] + ali_path = ali_ark_paths[dict_name] + acc_path = acc_paths[dict_name] + align_proc = subprocess.Popen( + [ + thirdparty_binary("align-equal-compiled"), + f"scp:{fst_path}", + feature_strings[dict_name], + f"ark:{ali_path}", + ], + stderr=log_file, + env=os.environ, + ) + align_proc.communicate() + stats_proc = subprocess.Popen( + [ + thirdparty_binary("gmm-acc-stats-ali"), + "--binary=true", + model_path, + feature_strings[dict_name], + f"ark:{ali_path}", + acc_path, + ], + stdin=align_proc.stdout, + stderr=log_file, + env=os.environ, + ) + stats_proc.communicate() + + +__all__ = ["MonophoneTrainer"] + + +class MonophoneTrainer(AcousticModelTrainingMixin): + """ + Configuration class for monophone training + + Attributes + ---------- + subset : int + Number of utterances to use, defaults to 2000 + initial_gaussians : int + Number of gaussians to begin training, defaults to 135 + max_gaussians : int + Total number of gaussians, defaults to 1000 + power : float + Exponent for number of gaussians according to occurrence counts, defaults to 0.25 + + See Also + -------- + :class:`~montreal_forced_aligner.acoustic_modeling.base.AcousticModelTrainingMixin` + For acoustic model training parsing parameters + """ + + def __init__( + self, + subset: int = 2000, + initial_gaussians: int = 135, + max_gaussians: int = 1000, + power: float = 0.25, + **kwargs, + ): + super().__init__(**kwargs) + self.subset = subset + self.initial_gaussians = initial_gaussians + 
self.max_gaussians = max_gaussians + self.power = power + + def mono_align_equal_arguments(self) -> list[MonoAlignEqualArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.acoustic_modeling.monophone.mono_align_equal_func` + + Returns + ------- + list[:class:`~montreal_forced_aligner.acoustic_modeling.monophone.MonoAlignEqualArguments`] + Arguments for processing + """ + feat_strings = self.worker.construct_feature_proc_strings() + return [ + MonoAlignEqualArguments( + os.path.join(self.working_log_directory, f"mono_align_equal.{j.name}.log"), + j.current_dictionary_names, + feat_strings[j.name], + j.construct_path_dictionary(self.working_directory, "fsts", "scp"), + j.construct_path_dictionary(self.working_directory, "ali", "ark"), + j.construct_path_dictionary(self.working_directory, "0", "acc"), + self.model_path, + ) + for j in self.jobs + ] + + def compute_calculated_properties(self) -> None: + """Generate realignment iterations and initial gaussians based on configuration""" + self.realignment_iterations = [0] + for i in range(1, self.num_iterations): + if i <= int(self.num_iterations / 4): + self.realignment_iterations.append(i) + elif i <= int(self.num_iterations * 2 / 4): + if i - self.realignment_iterations[-1] > 1: + self.realignment_iterations.append(i) + else: + if i - self.realignment_iterations[-1] > 2: + self.realignment_iterations.append(i) + + @property + def train_type(self) -> str: + """Training identifier""" + return "mono" + + @property + def phone_type(self) -> str: + """Phone type""" + return "monophone" + + def mono_align_equal(self): + """ + Multiprocessing function that creates equal alignments for base monophone training. 
+ + See Also + -------- + :func:`~montreal_forced_aligner.acoustic_modeling.monophone.mono_align_equal_func` + Multiprocessing helper function for each job + :meth:`.MonophoneTrainer.mono_align_equal_arguments` + Job method for generating arguments for the helper function + :kaldi_src:`gmm-sum-accs` + Relevant Kaldi binary + :kaldi_src:`gmm-est` + Relevant Kaldi binary + :kaldi_steps:`train_mono` + Reference Kaldi script + """ + + arguments = self.mono_align_equal_arguments() + + if self.use_mp: + run_mp(mono_align_equal_func, arguments, self.working_log_directory) + else: + run_non_mp(mono_align_equal_func, arguments, self.working_log_directory) + + log_path = os.path.join(self.working_log_directory, "update.0.log") + with open(log_path, "w") as log_file: + acc_files = [] + for x in arguments: + acc_files.extend(sorted(x.acc_paths.values())) + sum_proc = subprocess.Popen( + [thirdparty_binary("gmm-sum-accs"), "-"] + acc_files, + stderr=log_file, + stdout=subprocess.PIPE, + env=os.environ, + ) + est_proc = subprocess.Popen( + [ + thirdparty_binary("gmm-est"), + "--min-gaussian-occupancy=3", + f"--mix-up={self.current_gaussians}", + f"--power={self.power}", + self.model_path, + "-", + self.next_model_path, + ], + stderr=log_file, + stdin=sum_proc.stdout, + env=os.environ, + ) + est_proc.communicate() + if not self.debug: + for f in acc_files: + os.remove(f) + + def _trainer_initialization(self) -> None: + """Monophone training initialization""" + self.iteration = 0 + tree_path = os.path.join(self.working_directory, "tree") + + feat_dim = self.worker.get_feat_dim() + + feature_string = self.worker.construct_base_feature_string() + shared_phones_path = os.path.join(self.worker.phones_dir, "sets.int") + init_log_path = os.path.join(self.working_log_directory, "init.log") + temp_feats_path = os.path.join(self.working_directory, "temp_feats") + with open(init_log_path, "w") as log_file: + subprocess.call( + [ + thirdparty_binary("subset-feats"), + "--n=10", + 
feature_string, + f"ark:{temp_feats_path}", + ], + stderr=log_file, + ) + subprocess.call( + [ + thirdparty_binary("gmm-init-mono"), + f"--shared-phones={shared_phones_path}", + f"--train-feats=ark:{temp_feats_path}", + os.path.join(self.worker.topo_path), + str(feat_dim), + self.model_path, + tree_path, + ], + stderr=log_file, + ) + proc = subprocess.Popen( + [thirdparty_binary("gmm-info"), "--print-args=false", self.model_path], + stderr=log_file, + stdout=subprocess.PIPE, + ) + stdout, stderr = proc.communicate() + num = stdout.decode("utf8") + matches = re.search(r"gaussians (\d+)", num) + num_gauss = int(matches.groups()[0]) + if os.path.exists(self.model_path): + os.remove(init_log_path) + os.remove(temp_feats_path) + self.initial_gaussians = num_gauss + self.current_gaussians = num_gauss + self.compile_train_graphs() + self.mono_align_equal() diff --git a/montreal_forced_aligner/acoustic_modeling/sat.py b/montreal_forced_aligner/acoustic_modeling/sat.py new file mode 100644 index 00000000..df6aab04 --- /dev/null +++ b/montreal_forced_aligner/acoustic_modeling/sat.py @@ -0,0 +1,323 @@ +"""Class definitions for Speaker Adapted Triphone trainer""" +from __future__ import annotations + +import os +import shutil +import subprocess +import time +from typing import NamedTuple + +from montreal_forced_aligner.acoustic_modeling.triphone import TriphoneTrainer +from montreal_forced_aligner.exceptions import KaldiProcessingError +from montreal_forced_aligner.utils import ( + log_kaldi_errors, + parse_logs, + run_mp, + run_non_mp, + thirdparty_binary, +) + +__all__ = ["SatTrainer"] + + +class AccStatsTwoFeatsArguments(NamedTuple): + """Arguments for :func:`~montreal_forced_aligner.acoustic_modeling.sat.acc_stats_two_feats_func`""" + + log_path: str + dictionaries: list[str] + ali_paths: dict[str, str] + acc_paths: dict[str, str] + model_path: str + feature_strings: dict[str, str] + si_feature_strings: dict[str, str] + + +def acc_stats_two_feats_func( + log_path: str, + 
dictionaries: list[str], + ali_paths: dict[str, str], + acc_paths: dict[str, str], + model_path: str, + feature_strings: dict[str, str], + si_feature_strings: dict[str, str], +) -> None: + """ + Multiprocessing function for accumulating stats across speaker-independent and + speaker-adapted features + + See Also + -------- + :meth:`.SatTrainer.create_align_model` + Main function that calls this function in parallel + :meth:`.SatTrainer.acc_stats_two_feats_arguments` + Job method for generating arguments for this function + :kaldi_src:`ali-to-post` + Relevant Kaldi binary + :kaldi_src:`gmm-acc-stats-twofeats` + Relevant Kaldi binary + + Parameters + ---------- + log_path: str + Path to save log output + dictionaries: list[str] + List of dictionary names + ali_paths: dict[str, str] + Dictionary of alignment archives per dictionary name + acc_paths: dict[str, str] + Dictionary of accumulated stats files per dictionary name + model_path: str + Path to the acoustic model file + feature_strings: dict[str, str] + Dictionary of feature strings per dictionary name + si_feature_strings: dict[str, str] + Dictionary of speaker-independent feature strings per dictionary name + """ + with open(log_path, "w", encoding="utf8") as log_file: + for dict_name in dictionaries: + ali_path = ali_paths[dict_name] + acc_path = acc_paths[dict_name] + feature_string = feature_strings[dict_name] + si_feature_string = si_feature_strings[dict_name] + ali_to_post_proc = subprocess.Popen( + [thirdparty_binary("ali-to-post"), f"ark:{ali_path}", "ark:-"], + stderr=log_file, + stdout=subprocess.PIPE, + env=os.environ, + ) + acc_proc = subprocess.Popen( + [ + thirdparty_binary("gmm-acc-stats-twofeats"), + model_path, + feature_string, + si_feature_string, + "ark,s,cs:-", + acc_path, + ], + stderr=log_file, + stdin=ali_to_post_proc.stdout, + env=os.environ, + ) + acc_proc.communicate() + + +class SatTrainer(TriphoneTrainer): + """ + Speaker adapted trainer (SAT), inherits from TriphoneTrainer + + 
Parameters + ---------- + subset : int + Number of utterances to use, defaults to 10000 + num_leaves : int + Number of states in the decision tree, defaults to 2500 + max_gaussians : int + Number of gaussians in the decision tree, defaults to 15000 + power : float + Exponent for number of gaussians according to occurrence counts, defaults to 0.2 + + See Also + -------- + :class:`~montreal_forced_aligner.acoustic_modeling.triphone.TriphoneTrainer` + For acoustic model training parsing parameters + + Attributes + ---------- + fmllr_iterations : list + List of iterations to perform fMLLR calculation + """ + + def __init__( + self, + subset: int = 10000, + num_leaves: int = 2500, + max_gaussians: int = 15000, + power: float = 0.2, + **kwargs, + ): + super().__init__(**kwargs) + self.subset = subset + self.num_leaves = num_leaves + self.max_gaussians = max_gaussians + self.power = power + self.fmllr_iterations = [] + + def acc_stats_two_feats_arguments(self) -> list[AccStatsTwoFeatsArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.acoustic_modeling.sat.acc_stats_two_feats_func` + + Returns + ------- + list[:class:`~montreal_forced_aligner.acoustic_modeling.sat.AccStatsTwoFeatsArguments`] + Arguments for processing + """ + feat_strings = self.worker.construct_feature_proc_strings() + si_feat_strings = self.worker.construct_feature_proc_strings(speaker_independent=True) + return [ + AccStatsTwoFeatsArguments( + os.path.join(self.working_log_directory, f"acc_stats_two_feats.{j.name}.log"), + j.current_dictionary_names, + j.construct_path_dictionary(self.working_directory, "ali", "ark"), + j.construct_path_dictionary(self.working_directory, "two_feat_acc", "ark"), + self.model_path, + feat_strings[j.name], + si_feat_strings[j.name], + ) + for j in self.jobs + ] + + def compute_calculated_properties(self) -> None: + """Generate realignment iterations, initial gaussians, and fMLLR iterations based on configuration""" + 
super().compute_calculated_properties() + self.fmllr_iterations = [] + max_fmllr_iter = int(self.num_iterations / 2) - 1 + for i in range(1, max_fmllr_iter): + if i < max_fmllr_iter / 2 and i % 2 == 0: + self.fmllr_iterations.append(i) + self.fmllr_iterations.append(max_fmllr_iter) + + def _trainer_initialization(self) -> None: + """Speaker adapted training initialization""" + self.speaker_independent = False + if os.path.exists(os.path.join(self.working_directory, "1.mdl")): + return + if os.path.exists(os.path.join(self.previous_aligner.working_directory, "lda.mat")): + shutil.copyfile( + os.path.join(self.previous_aligner.working_directory, "lda.mat"), + os.path.join(self.working_directory, "lda.mat"), + ) + self.tree_stats() + self._setup_tree() + + self.compile_train_graphs() + + self.convert_alignments() + os.rename(self.model_path, self.next_model_path) + + self.iteration = 1 + if os.path.exists(os.path.join(self.previous_aligner.working_directory, "trans.0.ark")): + for j in self.jobs: + for path in j.construct_path_dictionary( + self.previous_aligner.working_directory, "trans", "ark" + ).values(): + shutil.copy( + path, + path.replace( + self.previous_aligner.working_directory, self.working_directory + ), + ) + else: + + self.calc_fmllr() + self.initial_fmllr = False + parse_logs(self.working_log_directory) + + def finalize_training(self) -> None: + """ + Finalize training and create a speaker independent model for initial alignment + + Raises + ------ + :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` + If there were any errors in running Kaldi binaries + """ + try: + self.create_align_model() + super().finalize_training() + shutil.copy( + os.path.join(self.working_directory, f"{self.num_iterations+1}.alimdl"), + os.path.join(self.working_directory, "final.alimdl"), + ) + except Exception as e: + if isinstance(e, KaldiProcessingError): + log_kaldi_errors(e.error_logs, self.logger) + e.update_log_file(self.logger) + raise + + def 
train_iteration(self) -> None: + """ + Run a single training iteration + """ + if os.path.exists(self.next_model_path): + self.iteration += 1 + return + if self.iteration in self.realignment_iterations: + self.align_utterances() + if self.debug: + self.compute_alignment_improvement() + if self.iteration in self.fmllr_iterations: + self.calc_fmllr() + + self.acc_stats() + parse_logs(self.working_log_directory) + if self.iteration < self.final_gaussian_iteration: + self.increment_gaussians() + self.iteration += 1 + + @property + def alignment_model_path(self) -> str: + """Alignment model path""" + path = self.model_path.replace(".mdl", ".alimdl") + if os.path.exists(path): + return path + return self.model_path + + def create_align_model(self) -> None: + """ + Create alignment model for speaker-adapted training that will use speaker-independent + features in later aligning. + + See Also + -------- + :func:`~montreal_forced_aligner.acoustic_modeling.sat.acc_stats_two_feats_func` + Multiprocessing helper function for each job + :meth:`.SatTrainer.acc_stats_two_feats_arguments` + Job method for generating arguments for the helper function + :kaldi_src:`gmm-est` + Relevant Kaldi binary + :kaldi_src:`gmm-sum-accs` + Relevant Kaldi binary + :kaldi_steps:`train_sat` + Reference Kaldi script + """ + self.logger.info("Creating alignment model for speaker-independent features...") + begin = time.time() + log_directory = self.working_log_directory + + arguments = self.acc_stats_two_feats_arguments() + if self.use_mp: + run_mp(acc_stats_two_feats_func, arguments, log_directory) + else: + run_non_mp(acc_stats_two_feats_func, arguments, log_directory) + + log_path = os.path.join(self.working_log_directory, "align_model_est.log") + with open(log_path, "w", encoding="utf8") as log_file: + + acc_files = [] + for x in arguments: + acc_files.extend(x.acc_paths.values()) + sum_proc = subprocess.Popen( + [thirdparty_binary("gmm-sum-accs"), "-"] + acc_files, + stderr=log_file, + 
stdout=subprocess.PIPE, + env=os.environ, + ) + est_proc = subprocess.Popen( + [ + thirdparty_binary("gmm-est"), + "--remove-low-count-gaussians=false", + f"--power={self.power}", + self.model_path, + "-", + self.model_path.replace(".mdl", ".alimdl"), + ], + stdin=sum_proc.stdout, + stderr=log_file, + env=os.environ, + ) + est_proc.communicate() + parse_logs(self.working_log_directory) + if not self.debug: + for f in acc_files: + os.remove(f) + self.logger.debug(f"Alignment model creation took {time.time() - begin}") diff --git a/montreal_forced_aligner/acoustic_modeling/trainer.py b/montreal_forced_aligner/acoustic_modeling/trainer.py new file mode 100644 index 00000000..261a06cd --- /dev/null +++ b/montreal_forced_aligner/acoustic_modeling/trainer.py @@ -0,0 +1,358 @@ +"""Class definitions for trainable aligners""" +from __future__ import annotations + +import os +import time +from typing import TYPE_CHECKING, Any, Optional + +import yaml + +from montreal_forced_aligner.abc import ModelExporterMixin, TopLevelMfaWorker +from montreal_forced_aligner.acoustic_modeling.base import AcousticModelTrainingMixin +from montreal_forced_aligner.acoustic_modeling.lda import LdaTrainer +from montreal_forced_aligner.acoustic_modeling.monophone import MonophoneTrainer +from montreal_forced_aligner.acoustic_modeling.sat import SatTrainer +from montreal_forced_aligner.acoustic_modeling.triphone import TriphoneTrainer +from montreal_forced_aligner.alignment.base import CorpusAligner +from montreal_forced_aligner.exceptions import ConfigError, KaldiProcessingError +from montreal_forced_aligner.helper import parse_old_features +from montreal_forced_aligner.models import AcousticModel +from montreal_forced_aligner.utils import log_kaldi_errors + +if TYPE_CHECKING: + from argparse import Namespace + + from montreal_forced_aligner.abc import MetaDict + +__all__ = ["TrainableAligner"] + + +class TrainableAligner(CorpusAligner, TopLevelMfaWorker, ModelExporterMixin): + """ + Train 
acoustic model + + Parameters + ---------- + training_configuration : list[tuple[str, dict[str, Any]]] + Training identifiers and parameters for training blocks + + See Also + -------- + :class:`~montreal_forced_aligner.alignment.base.CorpusAligner` + For dictionary and corpus parsing parameters and alignment parameters + :class:`~montreal_forced_aligner.abc.TopLevelMfaWorker` + For top-level parameters + :class:`~montreal_forced_aligner.abc.ModelExporterMixin` + For model export parameters + + Attributes + ---------- + param_dict: dict[str, Any] + Parameters to pass to training blocks + final_identifier: str + Identifier of the final training block + current_subset: int + Current training block's subset + current_acoustic_model: :class:`~montreal_forced_aligner.models.AcousticModel` + Acoustic model to use in aligning, based on previous training block + training_configs: dict[str, :class:`~montreal_forced_aligner.acoustic_modeling.base.AcousticModelTrainingMixin`] + Training blocks + """ + + def __init__(self, training_configuration: list[tuple[str, dict[str, Any]]] = None, **kwargs): + self.param_dict = { + k: v + for k, v in kwargs.items() + if not k.endswith("_directory") + and not k.endswith("_path") + and k not in ["clean", "num_jobs", "speaker_characters"] + } + self.final_identifier = None + self.current_subset: int = 0 + self.current_aligner = None + self.current_trainer = None + self.current_acoustic_model: Optional[AcousticModel] = None + super().__init__(**kwargs) + os.makedirs(self.output_directory, exist_ok=True) + self.training_configs: dict[str, AcousticModelTrainingMixin] = {} + if training_configuration is None: + training_configuration = [ + ("monophone", {}), + ("triphone", {}), + ("lda", {}), + ("sat", {}), + ("sat", {"subset": 0, "num_leaves": 4200, "max_gaussians": 40000}), + ] + for k, v in training_configuration: + self.add_config(k, v) + + @classmethod + def parse_parameters( + cls, + config_path: Optional[str] = None, + args: 
Optional[Namespace] = None, + unknown_args: Optional[list[str]] = None, + ) -> MetaDict: + """ + Parse configuration parameters from a config file and command line arguments + + Parameters + ---------- + config_path: str, optional + Path to yaml configuration file + args: :class:`~argparse.Namespace`, optional + Arguments parsed by argparse + unknown_args: list[str], optional + List of unknown arguments from argparse + + Returns + ------- + dict[str, Any] + Dictionary of specified configuration parameters + """ + global_params = {} + training_params = [] + if config_path: + with open(config_path, "r", encoding="utf8") as f: + data = yaml.load(f, Loader=yaml.SafeLoader) + training_params = [] + for k, v in data.items(): + if k == "training": + for t in v: + for k2, v2 in t.items(): + if "features" in v2: + global_params.update(parse_old_features(v2["features"])) + del v2["features"] + training_params.append((k2, v2)) + elif k == "features": + global_params.update(parse_old_features(v)) + else: + global_params[k] = v + if not training_params: + raise ConfigError(f"No 'training' block found in {config_path}") + else: # default training configuration + training_params.append(("monophone", {})) + training_params.append(("triphone", {})) + training_params.append(("lda", {})) + training_params.append(("sat", {})) + training_params.append( + ("sat", {"subset": 0, "num_leaves": 4200, "max_gaussians": 40000}) + ) + if training_params: + if training_params[0][0] != "monophone": + raise ConfigError("The first round of training must be monophone.") + global_params["training_configuration"] = training_params + global_params.update(cls.parse_args(args, unknown_args)) + return global_params + + def setup(self) -> None: + """Setup for acoustic model training""" + if self.initialized: + return + self.check_previous_run() + try: + self.load_corpus() + for config in self.training_configs.values(): + config.non_silence_phones = self.non_silence_phones + except Exception as e: + if 
isinstance(e, KaldiProcessingError): + log_kaldi_errors(e.error_logs, self.logger) + e.update_log_file(self.logger) + raise + self.initialized = True + + @property + def workflow_identifier(self) -> str: + """Acoustic model training identifier""" + return "train_acoustic_model" + + @property + def configuration(self) -> MetaDict: + """Configuration for the worker""" + config = super().configuration + config.update( + { + "dictionary_path": self.dictionary_model.path, + "corpus_directory": self.corpus_directory, + } + ) + return config + + @property + def meta(self) -> MetaDict: + """Metadata about the final round of training""" + return self.training_configs[self.final_identifier].meta + + def add_config(self, train_type: str, params: MetaDict) -> None: + """ + Add a trainer to the pipeline + + Parameters + ---------- + train_type: str + Type of trainer to add, one of ``monophone``, ``triphone``, ``lda`` or ``sat`` + params: dict[str, Any] + Parameters to initialize trainer + + Raises + ------ + :class:`~montreal_forced_aligner.exceptions.ConfigError` + If an invalid train_type is specified + """ + p = {} + p.update(self.param_dict) + p.update(params) + identifier = train_type + index = 1 + while identifier in self.training_configs: + identifier = f"{train_type}_{index}" + index += 1 + self.final_identifier = identifier + if train_type == "monophone": + p = { + k: v for k, v in p.items() if k in MonophoneTrainer.get_configuration_parameters() + } + config = MonophoneTrainer(identifier=identifier, worker=self, **p) + elif train_type == "triphone": + p = {k: v for k, v in p.items() if k in TriphoneTrainer.get_configuration_parameters()} + config = TriphoneTrainer(identifier=identifier, worker=self, **p) + elif train_type == "lda": + p = {k: v for k, v in p.items() if k in LdaTrainer.get_configuration_parameters()} + config = LdaTrainer(identifier=identifier, worker=self, **p) + elif train_type == "sat": + p = {k: v for k, v in p.items() if k in 
SatTrainer.get_configuration_parameters()} + config = SatTrainer(identifier=identifier, worker=self, **p) + else: + raise ConfigError(f"Invalid training type '{train_type}' in config file") + + self.training_configs[identifier] = config + + def export_model(self, output_model_path: str) -> None: + """ + Export an acoustic model to the specified path + + Parameters + ---------- + output_model_path : str + Path to save acoustic model + """ + self.training_configs[self.final_identifier].export_model(output_model_path) + self.logger.info(f"Saved model to {output_model_path}") + + @property + def backup_output_directory(self) -> Optional[str]: + """Backup directory if overwriting files is not allowed""" + if self.overwrite: + return None + return os.path.join(self.working_directory, "textgrids") + + @property + def tree_path(self) -> str: + """Tree path of the final model""" + return self.training_configs[self.final_identifier].tree_path + + def train(self, generate_final_alignments: bool = True) -> None: + """ + Run through the training configurations to produce a final acoustic model + + Parameters + ---------- + generate_final_alignments: bool + Flag for whether final alignments should be generated at the end of training, defaults to True + """ + self.setup() + previous = None + begin = time.time() + for trainer in self.training_configs.values(): + self.current_subset = trainer.subset + if previous is not None: + self.current_aligner = previous.identifier + os.makedirs(self.working_directory, exist_ok=True) + self.current_acoustic_model = AcousticModel( + previous.exported_model_path, self.working_directory + ) + self.align() + trainer.train() + previous = trainer + self.logger.info(f"Completed training in {time.time()-begin} seconds!") + + if generate_final_alignments: + self.current_subset = None + self.current_aligner = previous.identifier + os.makedirs(self.working_log_directory, exist_ok=True) + self.current_acoustic_model = AcousticModel( + 
previous.exported_model_path, self.working_directory + ) + self.align() + + def align(self) -> None: + """ + Multiprocessing function that aligns based on the current model. + + See Also + -------- + :func:`~montreal_forced_aligner.alignment.multiprocessing.align_func` + Multiprocessing helper function for each job + :meth:`.AlignMixin.align_arguments` + Job method for generating arguments for the helper function + :kaldi_steps:`align_si` + Reference Kaldi script + :kaldi_steps:`align_fmllr` + Reference Kaldi script + """ + done_path = os.path.join(self.working_directory, "done") + if os.path.exists(done_path): + self.logger.debug(f"Skipping {self.current_aligner} alignments") + return + try: + self.current_acoustic_model.export_model(self.working_directory) + self.compile_train_graphs() + self.align_utterances() + if self.current_subset: + self.logger.debug( + f"Analyzing alignment diagnostics for {self.current_aligner} on {self.current_subset} utterances" + ) + else: + self.logger.debug( + f"Analyzing alignment diagnostics for {self.current_aligner} on the full corpus" + ) + self.compile_information() + with open(done_path, "w"): + pass + except Exception as e: + if isinstance(e, KaldiProcessingError): + log_kaldi_errors(e.error_logs, self.logger) + e.update_log_file(self.logger) + raise + + @property + def alignment_model_path(self) -> str: + """Current alignment model path""" + path = os.path.join(self.working_directory, "final.alimdl") + if os.path.exists(path): + return path + return self.model_path + + @property + def model_path(self) -> str: + """Current model path""" + return os.path.join(self.working_directory, "final.mdl") + + @property + def data_directory(self) -> str: + """Current data directory based on the trainer's subset""" + return self.subset_directory(self.current_subset) + + @property + def working_directory(self) -> Optional[str]: + """Working directory""" + if self.current_trainer is not None: + return self.current_trainer.working_directory 
+ if self.current_aligner is None: + return None + return os.path.join(self.output_directory, f"{self.current_aligner}_ali") + + @property + def working_log_directory(self) -> Optional[str]: + """Current log directory""" + return os.path.join(self.working_directory, "log") diff --git a/montreal_forced_aligner/acoustic_modeling/triphone.py b/montreal_forced_aligner/acoustic_modeling/triphone.py new file mode 100644 index 00000000..dfa4233b --- /dev/null +++ b/montreal_forced_aligner/acoustic_modeling/triphone.py @@ -0,0 +1,439 @@ +"""Class definitions for TriphoneTrainer""" +from __future__ import annotations + +import os +import subprocess +from typing import TYPE_CHECKING, NamedTuple + +from montreal_forced_aligner.acoustic_modeling.base import AcousticModelTrainingMixin +from montreal_forced_aligner.utils import parse_logs, run_mp, run_non_mp, thirdparty_binary + +if TYPE_CHECKING: + from ..abc import MetaDict + + +__all__ = ["TriphoneTrainer"] + + +class TreeStatsArguments(NamedTuple): + """Arguments for :func:`~montreal_forced_aligner.acoustic_modeling.triphone.tree_stats_func`""" + + log_path: str + dictionaries: list[str] + ci_phones: str + model_path: str + feature_strings: dict[str, str] + ali_paths: dict[str, str] + treeacc_paths: dict[str, str] + + +class ConvertAlignmentsArguments(NamedTuple): + """Arguments for :func:`~montreal_forced_aligner.acoustic_modeling.triphone.convert_alignments_func`""" + + log_path: str + dictionaries: list[str] + model_path: str + tree_path: str + align_model_path: str + ali_paths: dict[str, str] + new_ali_paths: dict[str, str] + + +def convert_alignments_func( + log_path: str, + dictionaries: list[str], + model_path: str, + tree_path: str, + align_model_path: str, + ali_paths: dict[str, str], + new_ali_paths: dict[str, str], +) -> None: + """ + Multiprocessing function for converting alignments from a previous trainer + + See Also + -------- + :meth:`.TriphoneTrainer.convert_alignments` + Main function that calls this 
function in parallel + :meth:`.TriphoneTrainer.convert_alignments_arguments` + Job method for generating arguments for this function + :kaldi_src:`convert-ali` + Relevant Kaldi binary + + Parameters + ---------- + log_path: str + Path to save log output + dictionaries: list[str] + List of dictionary names + model_path: str + Path to the acoustic model file + tree_path: str + Path to the acoustic model tree file + align_model_path: str + Path to the alignment acoustic model file + ali_paths: dict[str, str] + Dictionary of alignment archives per dictionary name + new_ali_paths: dict[str, str] + Dictionary of new alignment archives per dictionary name + """ + with open(log_path, "w", encoding="utf8") as log_file: + for dict_name in dictionaries: + ali_path = ali_paths[dict_name] + new_ali_path = new_ali_paths[dict_name] + subprocess.call( + [ + thirdparty_binary("convert-ali"), + align_model_path, + model_path, + tree_path, + f"ark:{ali_path}", + f"ark:{new_ali_path}", + ], + stderr=log_file, + ) + + +def tree_stats_func( + log_path: str, + dictionaries: list[str], + ci_phones: str, + model_path: str, + feature_strings: dict[str, str], + ali_paths: dict[str, str], + treeacc_paths: dict[str, str], +) -> None: + """ + Multiprocessing function for calculating tree stats for training + + See Also + -------- + :meth:`.TriphoneTrainer.tree_stats` + Main function that calls this function in parallel + :meth:`.TriphoneTrainer.tree_stats_arguments` + Job method for generating arguments for this function + :kaldi_src:`acc-tree-stats` + Relevant Kaldi binary + + Parameters + ---------- + log_path: str + Path to save log output + dictionaries: list[str] + List of dictionary names + ci_phones: str + Colon-separated list of context-independent phones + model_path: str + Path to the acoustic model file + feature_strings: dict[str, str] + Dictionary of feature strings per dictionary name + ali_paths: dict[str, str] + Dictionary of alignment archives per dictionary name + 
treeacc_paths: dict[str, str] + Dictionary of accumulated tree stats files per dictionary name + """ + with open(log_path, "w", encoding="utf8") as log_file: + for dict_name in dictionaries: + feature_string = feature_strings[dict_name] + ali_path = ali_paths[dict_name] + treeacc_path = treeacc_paths[dict_name] + subprocess.call( + [ + thirdparty_binary("acc-tree-stats"), + f"--ci-phones={ci_phones}", + model_path, + feature_string, + f"ark:{ali_path}", + treeacc_path, + ], + stderr=log_file, + ) + + +class TriphoneTrainer(AcousticModelTrainingMixin): + """ + Triphone trainer + + Parameters + ---------- + subset : int + Number of utterances to use, defaults to 5000 + num_iterations : int + Number of training iterations to perform, defaults to 35 + num_leaves : int + Number of states in the decision tree, defaults to 1000 + max_gaussians : int + Number of gaussians in the decision tree, defaults to 10000 + cluster_threshold : int + For build-tree control final bottom-up clustering of leaves, defaults to 100 + + See Also + -------- + :class:`~montreal_forced_aligner.acoustic_modeling.base.AcousticModelTrainingMixin` + For acoustic model training parsing parameters + """ + + def __init__( + self, + subset: int = 5000, + num_iterations: int = 35, + num_leaves: int = 1000, + max_gaussians: int = 10000, + cluster_threshold: int = -1, + **kwargs, + ): + super().__init__(**kwargs) + self.subset = subset + self.num_iterations = num_iterations + self.num_leaves = num_leaves + self.max_gaussians = max_gaussians + self.cluster_threshold = cluster_threshold + + def tree_stats_arguments(self) -> list[TreeStatsArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.acoustic_modeling.triphone.tree_stats_func` + + + Returns + ------- + list[:class:`~montreal_forced_aligner.acoustic_modeling.triphone.TreeStatsArguments`] + Arguments for processing + """ + feat_strings = self.worker.construct_feature_proc_strings() + alignment_model_path = 
os.path.join(self.previous_aligner.working_directory, "final.mdl") + return [ + TreeStatsArguments( + os.path.join(self.working_log_directory, f"acc_tree.{j.name}.log"), + j.current_dictionary_names, + self.worker.silence_csl, + alignment_model_path, + feat_strings[j.name], + j.construct_path_dictionary(self.previous_aligner.working_directory, "ali", "ark"), + j.construct_path_dictionary(self.working_directory, "tree", "acc"), + ) + for j in self.jobs + ] + + def convert_alignments_arguments(self) -> list[ConvertAlignmentsArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.acoustic_modeling.triphone.convert_alignments_func` + + Returns + ------- + list[:class:`~montreal_forced_aligner.acoustic_modeling.triphone.ConvertAlignmentsArguments`] + Arguments for processing + """ + return [ + ConvertAlignmentsArguments( + os.path.join(self.working_log_directory, f"convert_alignments.{j.name}.log"), + j.current_dictionary_names, + self.model_path, + self.tree_path, + self.previous_aligner.alignment_model_path, + j.construct_path_dictionary(self.previous_aligner.working_directory, "ali", "ark"), + j.construct_path_dictionary(self.working_directory, "ali", "ark"), + ) + for j in self.jobs + ] + + def convert_alignments(self) -> None: + """ + Multiprocessing function that converts alignments from previous training + + See Also + -------- + :func:`~montreal_forced_aligner.acoustic_modeling.triphone.convert_alignments_func` + Multiprocessing helper function for each job + :meth:`.TriphoneTrainer.convert_alignments_arguments` + Job method for generating arguments for the helper function + :kaldi_steps:`train_deltas` + Reference Kaldi script + :kaldi_steps:`train_lda_mllt` + Reference Kaldi script + :kaldi_steps:`train_sat` + Reference Kaldi script + + """ + + jobs = self.convert_alignments_arguments() + if self.use_mp: + run_mp(convert_alignments_func, jobs, self.working_log_directory) + else: + run_non_mp(convert_alignments_func, jobs, 
self.working_log_directory) + + def acoustic_model_training_params(self) -> MetaDict: + """Configuration parameters""" + return { + "num_iterations": self.num_iterations, + "num_leaves": self.num_leaves, + "max_gaussians": self.max_gaussians, + "cluster_threshold": self.cluster_threshold, + } + + def compute_calculated_properties(self) -> None: + """Generate realignment iterations and initial gaussians based on configuration""" + for i in range(0, self.num_iterations, 10): + if i == 0: + continue + self.realignment_iterations.append(i) + self.initial_gaussians = self.num_leaves + self.current_gaussians = self.num_leaves + + @property + def train_type(self) -> str: + """Training identifier""" + return "tri" + + @property + def phone_type(self) -> str: + """Phone type""" + return "triphone" + + def _trainer_initialization(self) -> None: + """Triphone training initialization""" + self.tree_stats() + self._setup_tree() + + self.compile_train_graphs() + + self.convert_alignments() + os.rename(self.model_path, self.next_model_path) + + def tree_stats(self) -> None: + """ + Multiprocessing function that computes stats for decision tree training. 
+ + See Also + -------- + :func:`~montreal_forced_aligner.acoustic_modeling.triphone.tree_stats_func` + Multiprocessing helper function for each job + :meth:`.TriphoneTrainer.tree_stats_arguments` + Job method for generating arguments for the helper function + :kaldi_src:`sum-tree-stats` + Relevant Kaldi binary + :kaldi_steps:`train_deltas` + Reference Kaldi script + :kaldi_steps:`train_lda_mllt` + Reference Kaldi script + :kaldi_steps:`train_sat` + Reference Kaldi script + + """ + + jobs = self.tree_stats_arguments() + + if self.use_mp: + run_mp(tree_stats_func, jobs, self.working_log_directory) + else: + run_non_mp(tree_stats_func, jobs, self.working_log_directory) + + tree_accs = [] + for x in jobs: + tree_accs.extend(x.treeacc_paths.values()) + log_path = os.path.join(self.working_log_directory, "sum_tree_acc.log") + with open(log_path, "w", encoding="utf8") as log_file: + subprocess.call( + [ + thirdparty_binary("sum-tree-stats"), + os.path.join(self.working_directory, "treeacc"), + ] + + tree_accs, + stderr=log_file, + ) + if not self.debug: + for f in tree_accs: + os.remove(f) + + def _setup_tree(self) -> None: + """ + Set up the tree for the triphone model + + Raises + ------ + :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` + If there were any errors in running Kaldi binaries + """ + log_path = os.path.join(self.working_log_directory, "questions.log") + tree_path = os.path.join(self.working_directory, "tree") + treeacc_path = os.path.join(self.working_directory, "treeacc") + sets_int_path = os.path.join(self.worker.phones_dir, "sets.int") + roots_int_path = os.path.join(self.worker.phones_dir, "roots.int") + extra_question_int_path = os.path.join(self.worker.phones_dir, "extra_questions.int") + topo_path = self.worker.topo_path + questions_path = os.path.join(self.working_directory, "questions.int") + questions_qst_path = os.path.join(self.working_directory, "questions.qst") + with open(log_path, "w") as log_file: + subprocess.call( + [ + 
thirdparty_binary("cluster-phones"), + treeacc_path, + sets_int_path, + questions_path, + ], + stderr=log_file, + ) + + with open(extra_question_int_path, "r") as inf, open(questions_path, "a") as outf: + for line in inf: + outf.write(line) + + log_path = os.path.join(self.working_log_directory, "compile_questions.log") + with open(log_path, "w") as log_file: + subprocess.call( + [ + thirdparty_binary("compile-questions"), + topo_path, + questions_path, + questions_qst_path, + ], + stderr=log_file, + ) + + log_path = os.path.join(self.working_log_directory, "build_tree.log") + with open(log_path, "w") as log_file: + subprocess.call( + [ + thirdparty_binary("build-tree"), + "--verbose=1", + f"--max-leaves={self.initial_gaussians}", + f"--cluster-thresh={self.cluster_threshold}", + treeacc_path, + roots_int_path, + questions_qst_path, + topo_path, + tree_path, + ], + stderr=log_file, + ) + + log_path = os.path.join(self.working_log_directory, "init_model.log") + occs_path = os.path.join(self.working_directory, "0.occs") + mdl_path = self.model_path + with open(log_path, "w") as log_file: + subprocess.call( + [ + thirdparty_binary("gmm-init-model"), + f"--write-occs={occs_path}", + tree_path, + treeacc_path, + topo_path, + mdl_path, + ], + stderr=log_file, + ) + + log_path = os.path.join(self.working_log_directory, "mixup.log") + with open(log_path, "w") as log_file: + subprocess.call( + [ + thirdparty_binary("gmm-mixup"), + f"--mix-up={self.initial_gaussians}", + mdl_path, + occs_path, + mdl_path, + ], + stderr=log_file, + ) + os.remove(treeacc_path) + os.rename(occs_path, self.next_occs_path) + parse_logs(self.working_log_directory) diff --git a/montreal_forced_aligner/aligner/__init__.py b/montreal_forced_aligner/aligner/__init__.py deleted file mode 100644 index 7af0241c..00000000 --- a/montreal_forced_aligner/aligner/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -""" -Aligners -======== - -""" -from .adapting import AdaptingAligner # noqa -from .base import 
BaseAligner # noqa -from .pretrained import PretrainedAligner # noqa -from .trainable import TrainableAligner # noqa - -__all__ = [ - "AdaptingAligner", - "PretrainedAligner", - "TrainableAligner", - "BaseAligner", - "adapting", - "base", - "pretrained", - "trainable", -] diff --git a/montreal_forced_aligner/aligner/adapting.py b/montreal_forced_aligner/aligner/adapting.py deleted file mode 100644 index 1205039f..00000000 --- a/montreal_forced_aligner/aligner/adapting.py +++ /dev/null @@ -1,269 +0,0 @@ -"""Class definitions for adapting acoustic models""" -from __future__ import annotations - -import os -import shutil -from typing import TYPE_CHECKING, Optional - -from ..abc import Trainer -from ..exceptions import KaldiProcessingError -from ..models import AcousticModel -from ..multiprocessing import ( - align, - calc_fmllr, - compile_information, - compile_train_graphs, - train_map, -) -from ..utils import log_kaldi_errors -from .base import BaseAligner - -if TYPE_CHECKING: - from logging import Logger - - from ..config import AlignConfig - from ..corpus import Corpus - from ..dictionary import MultispeakerDictionary - from ..models import MetaDict - from .pretrained import PretrainedAligner - - -__all__ = ["AdaptingAligner"] - - -class AdaptingAligner(BaseAligner, Trainer): - """ - Aligner adapts another acoustic model to the current data - - Parameters - ---------- - corpus : :class:`~montreal_forced_aligner.corpus.Corpus` - Corpus object for the dataset - dictionary : :class:`~montreal_forced_aligner.dictionary.MultispeakerDictionary` - Dictionary object for the pronunciation dictionary - pretrained_aligner: :class:`~montreal_forced_aligner.aligner.PretrainedAligner` - Pretrained aligner to use as input to training - align_config : :class:`~montreal_forced_aligner.config.AlignConfig` - Configuration for alignment - temp_directory : str, optional - Specifies the temporary directory root to save files need for Kaldi. 
- If not specified, it will be set to ``~/Documents/MFA`` - debug: bool - Flag for debug mode, default is False - verbose: bool - Flag for verbose mode, default is False - logger: :class:`~logging.Logger` - Logger to use - """ - - def __init__( - self, - corpus: Corpus, - dictionary: MultispeakerDictionary, - previous_aligner: PretrainedAligner, - align_config: AlignConfig, - temp_directory: Optional[str] = None, - debug: bool = False, - verbose: bool = False, - logger: Optional[Logger] = None, - ): - self.previous_aligner = previous_aligner - super().__init__( - corpus, - dictionary, - align_config, - temp_directory, - debug, - verbose, - logger, - acoustic_model=self.previous_aligner.acoustic_model, - ) - self.align_config.data_directory = corpus.split_directory - log_dir = os.path.join(self.align_directory, "log") - os.makedirs(log_dir, exist_ok=True) - self.align_config.logger = self.logger - self.logger.info("Done with setup!") - self.training_complete = False - self.mapping_tau = 20 - - def setup(self) -> None: - """Set up the aligner""" - super().setup() - self.previous_aligner.align() - self.acoustic_model.export_model(self.adapt_directory) - for f in ["final.mdl", "final.alimdl"]: - p = os.path.join(self.adapt_directory, f) - if not os.path.exists(p): - continue - os.rename(p, os.path.join(self.adapt_directory, f.replace("final", "0"))) - - @property - def align_directory(self) -> str: - """Align directory""" - return os.path.join(self.temp_directory, "adapted_align") - - @property - def adapt_directory(self) -> str: - """Adapt directory""" - return os.path.join(self.temp_directory, "adapt") - - @property - def working_directory(self) -> str: - """Current working directory""" - if self.training_complete: - return self.align_directory - return self.adapt_directory - - @property - def working_log_directory(self) -> str: - """Current log directory""" - return os.path.join(self.working_directory, "log") - - @property - def current_model_path(self): - 
"""Current acoustic model path""" - if self.training_complete: - return os.path.join(self.working_directory, "final.mdl") - return os.path.join(self.working_directory, "0.mdl") - - @property - def next_model_path(self): - """Next iteration's acoustic model path""" - return os.path.join(self.working_directory, "final.mdl") - - def train(self) -> None: - """Run the adaptation""" - done_path = os.path.join(self.adapt_directory, "done") - dirty_path = os.path.join(self.adapt_directory, "dirty") - if os.path.exists(done_path): - self.logger.info("Adapting already done, skipping.") - return - try: - self.logger.info("Adapting pretrained model...") - train_map(self) - self.training_complete = True - shutil.copyfile( - os.path.join(self.adapt_directory, "final.mdl"), - os.path.join(self.align_directory, "final.mdl"), - ) - shutil.copyfile( - os.path.join(self.adapt_directory, "final.occs"), - os.path.join(self.align_directory, "final.occs"), - ) - shutil.copyfile( - os.path.join(self.adapt_directory, "tree"), - os.path.join(self.align_directory, "tree"), - ) - if os.path.exists(os.path.join(self.adapt_directory, "final.alimdl")): - shutil.copyfile( - os.path.join(self.adapt_directory, "final.alimdl"), - os.path.join(self.align_directory, "final.alimdl"), - ) - if os.path.exists(os.path.join(self.adapt_directory, "lda.mat")): - shutil.copyfile( - os.path.join(self.adapt_directory, "lda.mat"), - os.path.join(self.align_directory, "lda.mat"), - ) - except Exception as e: - with open(dirty_path, "w"): - pass - if isinstance(e, KaldiProcessingError): - log_kaldi_errors(e.error_logs, self.logger) - e.update_log_file(self.logger.handlers[0].baseFilename) - raise - with open(done_path, "w"): - pass - - @property - def meta(self) -> MetaDict: - """Acoustic model metadata""" - from datetime import datetime - - from ..utils import get_mfa_version - - data = { - "phones": sorted(self.dictionary.config.non_silence_phones), - "version": get_mfa_version(), - "architecture": 
self.acoustic_model.meta["architecture"], - "train_date": str(datetime.now()), - "features": self.previous_aligner.align_config.feature_config.params(), - "multilingual_ipa": self.dictionary.config.multilingual_ipa, - } - if self.dictionary.config.multilingual_ipa: - data["strip_diacritics"] = self.dictionary.config.strip_diacritics - data["digraphs"] = self.dictionary.config.digraphs - return data - - def save(self, path, root_directory=None) -> None: - """ - Output an acoustic model and dictionary to the specified path - - Parameters - ---------- - path : str - Path to save acoustic model and dictionary - root_directory : str or None - Path for root directory of temporary files - """ - directory, filename = os.path.split(path) - basename, _ = os.path.splitext(filename) - acoustic_model = AcousticModel.empty(basename, root_directory=root_directory) - acoustic_model.add_meta_file(self) - acoustic_model.add_model(self.align_directory) - if directory: - os.makedirs(directory, exist_ok=True) - basename, _ = os.path.splitext(path) - acoustic_model.dump(path) - - def align(self, subset: Optional[int] = None) -> None: - """ - Align using the adapted model - - Parameters - ---------- - subset: int, optional - Number of utterances to align in corpus - """ - done_path = os.path.join(self.align_directory, "done") - dirty_path = os.path.join(self.align_directory, "dirty") - if os.path.exists(done_path): - self.logger.info("Alignment already done, skipping.") - return - try: - log_dir = os.path.join(self.align_directory, "log") - os.makedirs(log_dir, exist_ok=True) - compile_train_graphs(self) - - self.logger.info("Performing first-pass alignment...") - self.speaker_independent = True - align(self) - unaligned, average_log_like = compile_information(self) - self.logger.debug( - f"Prior to SAT, average per frame likelihood (this might not actually mean anything): {average_log_like}" - ) - if ( - not self.align_config.disable_sat - and 
self.previous_aligner.acoustic_model.feature_config.fmllr - and not os.path.exists(os.path.join(self.align_directory, "trans.0")) - ): - self.logger.info("Calculating fMLLR for speaker adaptation...") - calc_fmllr(self) - - self.speaker_independent = False - self.logger.info("Performing second-pass alignment...") - align(self) - - unaligned, average_log_like = compile_information(self) - self.logger.debug( - f"Following SAT, average per frame likelihood (this might not actually mean anything): {average_log_like}" - ) - - except Exception as e: - with open(dirty_path, "w"): - pass - if isinstance(e, KaldiProcessingError): - log_kaldi_errors(e.error_logs, self.logger) - e.update_log_file(self.logger.handlers[0].baseFilename) - raise - with open(done_path, "w"): - pass diff --git a/montreal_forced_aligner/aligner/base.py b/montreal_forced_aligner/aligner/base.py deleted file mode 100644 index 3985eb8c..00000000 --- a/montreal_forced_aligner/aligner/base.py +++ /dev/null @@ -1,280 +0,0 @@ -"""Class definitions for base aligner""" -from __future__ import annotations - -import logging -import os -import shutil -import time -from typing import TYPE_CHECKING, Optional - -from ..abc import Aligner -from ..config import TEMP_DIR -from ..exceptions import KaldiProcessingError -from ..multiprocessing import ( - align, - calc_fmllr, - compile_information, - compile_train_graphs, - convert_ali_to_textgrids, -) -from ..utils import log_kaldi_errors - -if TYPE_CHECKING: - from logging import Logger - - import montreal_forced_aligner - - from ..config import AlignConfig - from ..corpus.base import Corpus - from ..dictionary import MultispeakerDictionary - from ..models import AcousticModel - -__all__ = ["BaseAligner"] - - -class BaseAligner(Aligner): - """ - Base aligner class for common aligner functions - - Parameters - ---------- - corpus : :class:`~montreal_forced_aligner.corpus.Corpus` - Corpus object for the dataset - dictionary : 
:class:`~montreal_forced_aligner.dictionary.MultispeakerDictionary` - Dictionary object for the pronunciation dictionary - align_config : :class:`~montreal_forced_aligner.config.AlignConfig` - Configuration for alignment - temp_directory : str, optional - Specifies the temporary directory root to save files need for Kaldi. - If not specified, it will be set to ``~/Documents/MFA`` - debug : bool - Flag for running in debug mode, defaults to false - verbose : bool - Flag for running in verbose mode, defaults to false - logger : :class:`~logging.Logger` - Logger to use - """ - - def __init__( - self, - corpus: Corpus, - dictionary: MultispeakerDictionary, - align_config: AlignConfig, - temp_directory: Optional[str] = None, - debug: bool = False, - verbose: bool = False, - logger: Optional[Logger] = None, - acoustic_model: Optional[AcousticModel] = None, - ): - super().__init__(corpus, dictionary, align_config) - if not temp_directory: - temp_directory = TEMP_DIR - self.temp_directory = temp_directory - os.makedirs(self.temp_directory, exist_ok=True) - self.log_file = os.path.join(self.temp_directory, "aligner.log") - if logger is None: - self.logger = logging.getLogger("corpus_setup") - self.logger.setLevel(logging.INFO) - handler = logging.FileHandler(self.log_file, "w", "utf-8") - handler.setFormatter = logging.Formatter("%(name)s %(message)s") - self.logger.addHandler(handler) - else: - self.logger = logger - self.acoustic_model = None - self.verbose = verbose - self.debug = debug - self.speaker_independent = True - self.uses_cmvn = True - self.uses_splices = False - self.uses_voiced = False - self.iteration = None - self.acoustic_model = acoustic_model - self.setup() - - def setup(self) -> None: - """ - Set up dictionary, corpus and configurations - """ - self.dictionary.set_word_set(self.corpus.word_set) - self.dictionary.write() - self.corpus.initialize_corpus(self.dictionary, self.align_config.feature_config) - self.align_config.silence_csl = 
self.dictionary.config.silence_csl - self.data_directory = self.corpus.split_directory - self.feature_config = self.align_config.feature_config - - @property - def use_mp(self) -> bool: - """Flag for using multiprocessing""" - return self.align_config.use_mp - - @property - def meta(self) -> montreal_forced_aligner.abc.MetaDict: - """Metadata for the trained model""" - from ..utils import get_mfa_version - - data = { - "phones": sorted(self.dictionary.config.non_silence_phones), - "version": get_mfa_version(), - "architecture": "gmm-hmm", - "features": "mfcc+deltas", - } - return data - - @property - def align_options(self): - """Options for alignment""" - options = self.align_config.align_options - options["optional_silence_csl"] = self.dictionary.config.optional_silence_csl - return options - - @property - def fmllr_options(self): - """Options for fMLLR""" - options = self.align_config.fmllr_options - options["silence_csl"] = self.dictionary.config.silence_csl - return options - - @property - def align_directory(self) -> str: - """Align directory""" - return os.path.join(self.temp_directory, "align") - - @property - def working_directory(self) -> str: - """Current working directory""" - return self.align_directory - - @property - def model_path(self) -> str: - """Current acoustic model path""" - return self.current_model_path - - @property - def current_model_path(self) -> str: - """Current acoustic model path""" - return os.path.join(self.align_directory, "final.mdl") - - @property - def alignment_model_path(self): - """Alignment acoustic model path""" - path = os.path.join(self.working_directory, "final.alimdl") - if self.speaker_independent and os.path.exists(path): - return path - return os.path.join(self.working_directory, "final.mdl") - - @property - def working_log_directory(self) -> str: - """Current log directory""" - return os.path.join(self.align_directory, "log") - - @property - def backup_output_directory(self) -> Optional[str]: - """Backup output 
directory""" - if self.align_config.overwrite: - return None - return os.path.join(self.align_directory, "textgrids") - - def compile_information(self, output_directory: str) -> None: - """ - Compile information about the quality of alignment - - Parameters - ---------- - output_directory: str - Directory to save information to - """ - issues, average_log_like = compile_information(self) - errors_path = os.path.join(output_directory, "output_errors.txt") - if os.path.exists(errors_path): - self.logger.warning( - "There were errors when generating the textgrids. See the output_errors.txt in the " - "output directory for more details." - ) - if issues: - issue_path = os.path.join(output_directory, "unaligned.txt") - with open(issue_path, "w", encoding="utf8") as f: - for u, r in sorted(issues.items()): - f.write(f"{u}\t{r}\n") - self.logger.warning( - f"There were {len(issues)} segments/files not aligned. Please see {issue_path} for more details on why " - "alignment failed for these files." - ) - if ( - self.backup_output_directory is not None - and os.path.exists(self.backup_output_directory) - and os.listdir(self.backup_output_directory) - ): - self.logger.info( - f"Some TextGrids were not output in the output directory to avoid overwriting existing files. " - f"You can find them in {self.backup_output_directory}, and if you would like to disable this " - f"behavior, you can rerun with the --overwrite flag or run `mfa configure --always_overwrite`." 
- ) - - def export_textgrids(self, output_directory: str) -> None: - """ - Export a TextGrid file for every sound file in the dataset - - Parameters - ---------- - output_directory: str - Directory to save to - """ - begin = time.time() - self.textgrid_output = output_directory - if self.backup_output_directory is not None and os.path.exists( - self.backup_output_directory - ): - shutil.rmtree(self.backup_output_directory, ignore_errors=True) - convert_ali_to_textgrids(self) - self.compile_information(output_directory) - self.logger.debug(f"Exported TextGrids in a total of {time.time() - begin} seconds") - - def align(self, subset: Optional[int] = None) -> None: - """ - Perform alignment - - Parameters - ---------- - subset: int, optional - Number of utterances to align - """ - done_path = os.path.join(self.align_directory, "done") - dirty_path = os.path.join(self.align_directory, "dirty") - if os.path.exists(done_path): - self.logger.info("Alignment already done, skipping.") - return - try: - compile_train_graphs(self) - log_dir = os.path.join(self.align_directory, "log") - os.makedirs(log_dir, exist_ok=True) - - self.logger.info("Performing first-pass alignment...") - align(self) - _, average_log_like = compile_information(self) - self.logger.debug( - f"Prior to SAT, average per frame likelihood (this might not actually mean anything): {average_log_like}" - ) - if ( - not self.align_config.disable_sat - and self.acoustic_model.feature_config.fmllr - and not os.path.exists(os.path.join(self.align_directory, "trans.0")) - ): - self.logger.info("Calculating fMLLR for speaker adaptation...") - calc_fmllr(self) - self.logger.info("Performing second-pass alignment...") - align(self) - - _, average_log_like = compile_information(self) - self.logger.debug( - f"Following SAT, average per frame likelihood (this might not actually mean anything): {average_log_like}" - ) - - except Exception as e: - with open(dirty_path, "w"): - pass - if isinstance(e, KaldiProcessingError): 
- log_kaldi_errors(e.error_logs, self.logger) - e.update_log_file(self.logger.handlers[0].baseFilename) - raise - with open(done_path, "w"): - pass diff --git a/montreal_forced_aligner/aligner/pretrained.py b/montreal_forced_aligner/aligner/pretrained.py deleted file mode 100644 index be66623f..00000000 --- a/montreal_forced_aligner/aligner/pretrained.py +++ /dev/null @@ -1,157 +0,0 @@ -"""Class definitions for aligning with pretrained acoustic models""" -from __future__ import annotations - -import os -from collections import Counter -from typing import TYPE_CHECKING, Optional - -from ..multiprocessing import generate_pronunciations -from .base import BaseAligner - -if TYPE_CHECKING: - from logging import Logger - - from ..config import AlignConfig - from ..corpus import Corpus - from ..dictionary import MultispeakerDictionary - from ..models import AcousticModel - -__all__ = ["PretrainedAligner"] - - -class PretrainedAligner(BaseAligner): - """ - Class for aligning a dataset using a pretrained acoustic model - - Parameters - ---------- - corpus : :class:`~montreal_forced_aligner.corpus.Corpus` - Corpus object for the dataset - dictionary : :class:`~montreal_forced_aligner.dictionary.MultispeakerDictionary` - Dictionary object for the pronunciation dictionary - acoustic_model : :class:`~montreal_forced_aligner.models.AcousticModel` - Archive containing the acoustic model and pronunciation dictionary - align_config : :class:`~montreal_forced_aligner.config.AlignConfig` - Configuration for alignment - temp_directory : str, optional - Specifies the temporary directory root to save files need for Kaldi. 
- If not specified, it will be set to ``~/Documents/MFA`` - debug: bool - Flag for debug mode, default is False - verbose: bool - Flag for verbose mode, default is False - logger: :class:`~logging.Logger` - Logger to use - """ - - def __init__( - self, - corpus: Corpus, - dictionary: MultispeakerDictionary, - acoustic_model: AcousticModel, - align_config: AlignConfig, - temp_directory: Optional[str] = None, - debug: bool = False, - verbose: bool = False, - logger: Optional[Logger] = None, - ): - super().__init__( - corpus, - dictionary, - align_config, - temp_directory, - debug, - verbose, - logger, - acoustic_model=acoustic_model, - ) - self.data_directory = corpus.split_directory - log_dir = os.path.join(self.align_directory, "log") - os.makedirs(log_dir, exist_ok=True) - self.align_config.logger = self.logger - self.logger.info("Done with setup!") - - @property - def model_directory(self) -> str: - """Model directory""" - return os.path.join(self.temp_directory, "model") - - def setup(self) -> None: - """Set up aligner""" - self.dictionary.config.non_silence_phones = self.acoustic_model.meta["phones"] - super(PretrainedAligner, self).setup() - self.acoustic_model.export_model(self.align_directory) - - @property - def ali_paths(self): - """Alignment archive paths""" - jobs = [x.align_arguments(self) for x in self.corpus.jobs] - ali_paths = [] - for j in jobs: - ali_paths.extend(j.ali_paths.values()) - return ali_paths - - def generate_pronunciations( - self, output_path: str, calculate_silence_probs: bool = False, min_count: int = 1 - ) -> None: - """ - Generate pronunciation probabilities for the dictionary - - Parameters - ---------- - output_path: str - Path to save new dictionary - calculate_silence_probs: bool - Flag for whether to calculate silence probabilities, default is False - min_count: int - Specifies the minimum count of words to include in derived probabilities, default is 1 - """ - pron_counts, utt_mapping = generate_pronunciations(self) - for 
dict_name, dictionary in self.dictionary.dictionary_mapping.items(): - counts = pron_counts[dict_name] - mapping = utt_mapping[dict_name] - if calculate_silence_probs: - sil_before_counts = Counter() - nonsil_before_counts = Counter() - sil_after_counts = Counter() - nonsil_after_counts = Counter() - sils = ["", "", ""] - for v in mapping.values(): - for i, w in enumerate(v): - if w in sils: - continue - prev_w = v[i - 1] - next_w = v[i + 1] - if prev_w in sils: - sil_before_counts[w] += 1 - else: - nonsil_before_counts[w] += 1 - if next_w in sils: - sil_after_counts[w] += 1 - else: - nonsil_after_counts[w] += 1 - - dictionary.pronunciation_probabilities = True - for word, prons in dictionary.words.items(): - if word not in counts: - for p in prons: - p["probability"] = 1 - else: - total = 0 - best_pron = 0 - best_count = 0 - for p in prons: - p["probability"] = min_count - if p["pronunciation"] in counts[word]: - p["probability"] += counts[word][p["pronunciation"]] - total += p["probability"] - if p["probability"] > best_count: - best_pron = p["pronunciation"] - best_count = p["probability"] - for p in prons: - if p["pronunciation"] == best_pron: - p["probability"] = 1 - else: - p["probability"] /= total - dictionary.words[word] = prons - dictionary.export_lexicon(output_path, probability=True) diff --git a/montreal_forced_aligner/aligner/trainable.py b/montreal_forced_aligner/aligner/trainable.py deleted file mode 100644 index faf47659..00000000 --- a/montreal_forced_aligner/aligner/trainable.py +++ /dev/null @@ -1,136 +0,0 @@ -"""Class definitions for trainable aligners""" -from __future__ import annotations - -from typing import TYPE_CHECKING, Optional - -from ..abc import Trainer -from .base import BaseAligner - -if TYPE_CHECKING: - from logging import Logger - - from ..aligner.pretrained import PretrainedAligner - from ..config import AlignConfig, TrainingConfig - from ..corpus import Corpus - from ..dictionary import MultispeakerDictionary - -__all__ = 
["TrainableAligner"] - - -class TrainableAligner(BaseAligner, Trainer): - """ - Aligner that aligns and trains acoustics models on a large dataset - - Parameters - ---------- - corpus : :class:`~montreal_forced_aligner.corpus.Corpus` - Corpus object for the dataset - dictionary : :class:`~montreal_forced_aligner.dictionary.MultispeakerDictionary` - Dictionary object for the pronunciation dictionary - training_config : :class:`~montreal_forced_aligner.config.TrainingConfig` - Configuration to train a model - align_config : :class:`~montreal_forced_aligner.config.AlignConfig` - Configuration for alignment - temp_directory : str, optional - Specifies the temporary directory root to save files need for Kaldi. - If not specified, it will be set to ``~/Documents/MFA`` - debug: bool - Flag for debug mode, default is False - verbose: bool - Flag for verbose mode, default is False - logger: :class:`~logging.Logger` - Logger to use - pretrained_aligner: :class:`~montreal_forced_aligner.aligner.pretrained.PretrainedAligner`, optional - Pretrained aligner to use as input to training - """ - - def __init__( - self, - corpus: Corpus, - dictionary: MultispeakerDictionary, - training_config: TrainingConfig, - align_config: AlignConfig, - temp_directory: Optional[str] = None, - debug: bool = False, - verbose: bool = False, - logger: Optional[Logger] = None, - pretrained_aligner: Optional[PretrainedAligner] = None, - ): - self.training_config = training_config - self.pretrained_aligner = pretrained_aligner - if self.pretrained_aligner is not None: - acoustic_model = pretrained_aligner.acoustic_model - else: - acoustic_model = None - super(TrainableAligner, self).__init__( - corpus, - dictionary, - align_config, - temp_directory, - debug, - verbose, - logger, - acoustic_model=acoustic_model, - ) - for trainer in self.training_config.training_configs: - trainer.logger = self.logger - - def save(self, path: str, root_directory: Optional[str] = None) -> None: - """ - Output an acoustic 
model and dictionary to the specified path - - Parameters - ---------- - path : str - Path to save acoustic model and dictionary - root_directory : str or None - Path for root directory of temporary files - """ - self.training_config.values()[-1].save(path, root_directory) - self.logger.info(f"Saved model to {path}") - - @property - def meta(self) -> dict: - """Acoustic model parameters""" - from ..utils import get_mfa_version - - data = { - "phones": sorted(self.dictionary.config.non_silence_phones), - "version": get_mfa_version(), - "architecture": self.training_config.values()[-1].architecture, - "phone_type": self.training_config.values()[-1].phone_type, - "features": self.align_config.feature_config.params(), - } - return data - - @property - def model_path(self) -> str: - return self.training_config.values()[-1].model_path - - def train(self, generate_final_alignments: bool = True) -> None: - """ - Run through the training configurations to produce a final acoustic model - - Parameters - ---------- - generate_final_alignments: bool - Flag for whether final alignments should be generated at the end of training, defaults to True - """ - previous = self.pretrained_aligner - for identifier, trainer in self.training_config.items(): - trainer.debug = self.debug - trainer.logger = self.logger - if previous is not None: - previous.align(trainer.subset) - trainer.init_training( - identifier, self.temp_directory, self.corpus, self.dictionary, previous - ) - trainer.train() - previous = trainer - if generate_final_alignments: - previous.align(None) - - @property - def align_directory(self) -> str: - """Align directory""" - return self.training_config.values()[-1].align_directory diff --git a/montreal_forced_aligner/alignment/__init__.py b/montreal_forced_aligner/alignment/__init__.py new file mode 100644 index 00000000..58d39d52 --- /dev/null +++ b/montreal_forced_aligner/alignment/__init__.py @@ -0,0 +1,22 @@ +""" +Aligners +======== + +""" +from 
montreal_forced_aligner.alignment.adapting import AdaptingAligner +from montreal_forced_aligner.alignment.base import CorpusAligner +from montreal_forced_aligner.alignment.mixins import AlignMixin +from montreal_forced_aligner.alignment.pretrained import DictionaryTrainer, PretrainedAligner + +__all__ = [ + "AdaptingAligner", + "PretrainedAligner", + "CorpusAligner", + "DictionaryTrainer", + "adapting", + "base", + "pretrained", + "mixins", + "AlignMixin", + "multiprocessing", +] diff --git a/montreal_forced_aligner/alignment/adapting.py b/montreal_forced_aligner/alignment/adapting.py new file mode 100644 index 00000000..6f9fa26a --- /dev/null +++ b/montreal_forced_aligner/alignment/adapting.py @@ -0,0 +1,395 @@ +"""Class definitions for adapting acoustic models""" +from __future__ import annotations + +import os +import shutil +import subprocess +import time +from typing import TYPE_CHECKING, NamedTuple + +from montreal_forced_aligner.abc import AdapterMixin +from montreal_forced_aligner.alignment.pretrained import PretrainedAligner +from montreal_forced_aligner.exceptions import KaldiProcessingError +from montreal_forced_aligner.models import AcousticModel +from montreal_forced_aligner.utils import log_kaldi_errors, run_mp, run_non_mp, thirdparty_binary + +if TYPE_CHECKING: + from montreal_forced_aligner.models import MetaDict + + +__all__ = ["AdaptingAligner"] + + +class MapAccStatsArguments(NamedTuple): + """Arguments for :func:`~montreal_forced_aligner.alignment.adapting.map_acc_stats_func`""" + + log_path: str + dictionaries: list[str] + feature_strings: dict[str, str] + model_path: str + ali_paths: dict[str, str] + acc_paths: dict[str, str] + + +def map_acc_stats_func( + log_path: str, + dictionaries: list[str], + feature_strings: dict[str, str], + model_path: str, + ali_paths: dict[str, str], + acc_paths: dict[str, str], +) -> None: + """ + Multiprocessing function for accumulating mapped stats for adapting acoustic models to new + domains + + See Also + 
-------- + :meth:`.AdaptingAligner.train_map` + Main function that calls this function in parallel + :meth:`.AdaptingAligner.map_acc_stats_arguments` + Job method for generating arguments for this function + :kaldi_src:`gmm-acc-stats-ali` + Relevant Kaldi binary + + Parameters + ---------- + log_path: str + Path to save log output + dictionaries: list[str] + List of dictionary names + feature_strings: dict[str, str] + Dictionary of feature strings per dictionary name + model_path: str + Path to the acoustic model file + ali_paths: dict[str, str] + Dictionary of alignment archives per dictionary name + acc_paths: dict[str, str] + Dictionary of accumulated stats files per dictionary name + """ + with open(log_path, "w", encoding="utf8") as log_file: + for dict_name in dictionaries: + feature_string = feature_strings[dict_name] + acc_path = acc_paths[dict_name] + ali_path = ali_paths[dict_name] + acc_proc = subprocess.Popen( + [ + thirdparty_binary("gmm-acc-stats-ali"), + model_path, + feature_string, + f"ark,s,cs:{ali_path}", + acc_path, + ], + stderr=log_file, + env=os.environ, + ) + acc_proc.communicate() + + +class AdaptingAligner(PretrainedAligner, AdapterMixin): + """ + Adapt an acoustic model to a new dataset + + Parameters + ---------- + mapping_tau: int + Tau to use in mapping stats between new domain data and pretrained model + + See Also + -------- + :class:`~montreal_forced_aligner.alignment.pretrained.PretrainedAligner` + For dictionary, corpus, and alignment parameters + :class:`~montreal_forced_aligner.abc.AdapterMixin` + For adapting parameters + + Attributes + ---------- + initialized: bool + Flag for whether initialization is complete + adaptation_done: bool + Flag for whether adaptation is complete + """ + + def __init__(self, mapping_tau: int = 20, **kwargs): + super().__init__(**kwargs) + self.mapping_tau = mapping_tau + self.initialized = False + self.adaptation_done = False + + def map_acc_stats_arguments(self, alignment=False) -> 
list[MapAccStatsArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.alignment.adapting.map_acc_stats_func` + + Returns + ------- + list[:class:`~montreal_forced_aligner.alignment.adapting.MapAccStatsArguments`] + Arguments for processing + """ + feat_strings = self.construct_feature_proc_strings() + if alignment: + model_path = self.alignment_model_path + else: + model_path = self.model_path + return [ + MapAccStatsArguments( + os.path.join(self.working_log_directory, f"map_acc_stats.{j.name}.log"), + j.current_dictionary_names, + feat_strings[j.name], + model_path, + j.construct_path_dictionary(self.working_directory, "ali", "ark"), + j.construct_path_dictionary(self.working_directory, "map", "acc"), + ) + for j in self.jobs + ] + + @property + def workflow_identifier(self) -> str: + """Adaptation identifier""" + return "adapt_acoustic_model" + + @property + def align_directory(self) -> str: + """Align directory""" + return os.path.join(self.output_directory, "adapted_align") + + @property + def working_directory(self) -> str: + """Current working directory""" + if self.adaptation_done: + return self.align_directory + return self.workflow_directory + + @property + def working_log_directory(self) -> str: + """Current log directory""" + return os.path.join(self.working_directory, "log") + + @property + def model_path(self): + """Current acoustic model path""" + if not self.adaptation_done: + return os.path.join(self.working_directory, "0.mdl") + return os.path.join(self.working_directory, "final.mdl") + + @property + def next_model_path(self): + """Mapped acoustic model path""" + return os.path.join(self.working_directory, "final.mdl") + + def train_map(self) -> None: + """ + Trains an adapted acoustic model through mapping model states and update those with + enough data. 
+ + See Also + -------- + :func:`~montreal_forced_aligner.alignment.adapting.map_acc_stats_func` + Multiprocessing helper function for each job + :meth:`.AdaptingAligner.map_acc_stats_arguments` + Job method for generating arguments for the helper function + :kaldi_src:`gmm-sum-accs` + Relevant Kaldi binary + :kaldi_src:`gmm-ismooth-stats` + Relevant Kaldi binary + :kaldi_src:`gmm-est` + Relevant Kaldi binary + :kaldi_steps:`train_map` + Reference Kaldi script + + """ + begin = time.time() + initial_mdl_path = os.path.join(self.working_directory, "0.mdl") + final_mdl_path = os.path.join(self.working_directory, "final.mdl") + log_directory = self.working_log_directory + os.makedirs(log_directory, exist_ok=True) + + jobs = self.map_acc_stats_arguments() + if self.use_mp: + run_mp(map_acc_stats_func, jobs, log_directory) + else: + run_non_mp(map_acc_stats_func, jobs, log_directory) + log_path = os.path.join(self.working_log_directory, "map_model_est.log") + occs_path = os.path.join(self.working_directory, "final.occs") + with open(log_path, "w", encoding="utf8") as log_file: + acc_files = [] + for j in jobs: + acc_files.extend(j.acc_paths.values()) + sum_proc = subprocess.Popen( + [thirdparty_binary("gmm-sum-accs"), "-"] + acc_files, + stderr=log_file, + stdout=subprocess.PIPE, + env=os.environ, + ) + ismooth_proc = subprocess.Popen( + [ + thirdparty_binary("gmm-ismooth-stats"), + "--smooth-from-model", + f"--tau={self.mapping_tau}", + initial_mdl_path, + "-", + "-", + ], + stderr=log_file, + stdin=sum_proc.stdout, + stdout=subprocess.PIPE, + env=os.environ, + ) + est_proc = subprocess.Popen( + [ + thirdparty_binary("gmm-est"), + "--update-flags=m", + f"--write-occs={occs_path}", + "--remove-low-count-gaussians=false", + initial_mdl_path, + "-", + final_mdl_path, + ], + stdin=ismooth_proc.stdout, + stderr=log_file, + env=os.environ, + ) + est_proc.communicate() + if self.uses_speaker_adaptation: + initial_alimdl_path = os.path.join(self.working_directory, "0.alimdl") 
+ final_alimdl_path = os.path.join(self.working_directory, "0.alimdl") + if os.path.exists(initial_alimdl_path): + self.speaker_independent = True + jobs = self.map_acc_stats_arguments(alignment=True) + if self.use_mp: + run_mp(map_acc_stats_func, jobs, log_directory) + else: + run_non_mp(map_acc_stats_func, jobs, log_directory) + + log_path = os.path.join(self.working_log_directory, "map_model_est.log") + with open(log_path, "w", encoding="utf8") as log_file: + acc_files = [] + for j in jobs: + acc_files.extend(j.acc_paths) + sum_proc = subprocess.Popen( + [thirdparty_binary("gmm-sum-accs"), "-"] + acc_files, + stderr=log_file, + stdout=subprocess.PIPE, + env=os.environ, + ) + ismooth_proc = subprocess.Popen( + [ + thirdparty_binary("gmm-ismooth-stats"), + "--smooth-from-model", + f"--tau={self.mapping_tau}", + initial_alimdl_path, + "-", + "-", + ], + stderr=log_file, + stdin=sum_proc.stdout, + stdout=subprocess.PIPE, + env=os.environ, + ) + est_proc = subprocess.Popen( + [ + thirdparty_binary("gmm-est"), + "--update-flags=m", + "--remove-low-count-gaussians=false", + initial_alimdl_path, + "-", + final_alimdl_path, + ], + stdin=ismooth_proc.stdout, + stderr=log_file, + env=os.environ, + ) + est_proc.communicate() + + self.logger.debug(f"Mapping models took {time.time() - begin}") + + def adapt(self) -> None: + """Run the adaptation""" + self.setup() + dirty_path = os.path.join(self.working_directory, "dirty") + done_path = os.path.join(self.working_directory, "done") + if os.path.exists(done_path): + self.logger.info("Adaptation already done, skipping.") + return + self.logger.info("Generating initial alignments...") + for f in ["final.mdl", "final.alimdl"]: + p = os.path.join(self.working_directory, f) + if not os.path.exists(p): + continue + os.rename(p, os.path.join(self.working_directory, f.replace("final", "0"))) + self.align() + os.makedirs(self.align_directory, exist_ok=True) + try: + self.logger.info("Adapting pretrained model...") + self.train_map() + 
self.export_model(os.path.join(self.working_log_directory, "acoustic_model.zip")) + shutil.copyfile( + os.path.join(self.working_directory, "final.mdl"), + os.path.join(self.align_directory, "final.mdl"), + ) + shutil.copyfile( + os.path.join(self.working_directory, "final.occs"), + os.path.join(self.align_directory, "final.occs"), + ) + shutil.copyfile( + os.path.join(self.working_directory, "tree"), + os.path.join(self.align_directory, "tree"), + ) + if os.path.exists(os.path.join(self.working_directory, "final.alimdl")): + shutil.copyfile( + os.path.join(self.working_directory, "final.alimdl"), + os.path.join(self.align_directory, "final.alimdl"), + ) + if os.path.exists(os.path.join(self.working_directory, "lda.mat")): + shutil.copyfile( + os.path.join(self.working_directory, "lda.mat"), + os.path.join(self.align_directory, "lda.mat"), + ) + self.adaptation_done = True + except Exception as e: + with open(dirty_path, "w"): + pass + if isinstance(e, KaldiProcessingError): + log_kaldi_errors(e.error_logs, self.logger) + e.update_log_file(self.logger) + raise + with open(done_path, "w"): + pass + + @property + def meta(self) -> MetaDict: + """Acoustic model metadata""" + from datetime import datetime + + from ..utils import get_mfa_version + + data = { + "phones": sorted(self.non_silence_phones), + "version": get_mfa_version(), + "architecture": self.acoustic_model.meta["architecture"], + "train_date": str(datetime.now()), + "features": self.feature_options, + "multilingual_ipa": self.multilingual_ipa, + } + if self.multilingual_ipa: + data["strip_diacritics"] = self.strip_diacritics + data["digraphs"] = self.digraphs + return data + + def export_model(self, output_model_path: str) -> None: + """ + Output an acoustic model to the specified path + + Parameters + ---------- + output_model_path : str + Path to save adapted acoustic model + """ + directory, filename = os.path.split(output_model_path) + basename, _ = os.path.splitext(filename) + acoustic_model = 
AcousticModel.empty(basename, root_directory=self.working_log_directory) + acoustic_model.add_meta_file(self) + acoustic_model.add_model(self.align_directory) + if directory: + os.makedirs(directory, exist_ok=True) + basename, _ = os.path.splitext(output_model_path) + acoustic_model.dump(output_model_path) diff --git a/montreal_forced_aligner/alignment/base.py b/montreal_forced_aligner/alignment/base.py new file mode 100644 index 00000000..55b0b6e1 --- /dev/null +++ b/montreal_forced_aligner/alignment/base.py @@ -0,0 +1,537 @@ +"""Class definitions for base aligner""" +from __future__ import annotations + +import multiprocessing as mp +import os +import shutil +import sys +import time +import traceback +from typing import Optional + +from montreal_forced_aligner.abc import FileExporterMixin +from montreal_forced_aligner.alignment.mixins import AlignMixin +from montreal_forced_aligner.alignment.multiprocessing import ( + AliToCtmArguments, + CleanupWordCtmArguments, + CleanupWordCtmProcessWorker, + CombineCtmArguments, + CombineProcessWorker, + ExportPreparationProcessWorker, + ExportTextGridArguments, + ExportTextGridProcessWorker, + NoCleanupWordCtmArguments, + NoCleanupWordCtmProcessWorker, + PhoneCtmArguments, + PhoneCtmProcessWorker, + ali_to_ctm_func, +) +from montreal_forced_aligner.corpus.acoustic_corpus import AcousticCorpusPronunciationMixin +from montreal_forced_aligner.exceptions import AlignmentExportError +from montreal_forced_aligner.textgrid import ( + ctm_to_textgrid, + output_textgrid_writing_errors, + parse_from_phone, + parse_from_word, + parse_from_word_no_cleanup, + process_ctm_line, +) +from montreal_forced_aligner.utils import Stopped, run_mp, run_non_mp + +__all__ = ["CorpusAligner"] + + +class CorpusAligner(AcousticCorpusPronunciationMixin, AlignMixin, FileExporterMixin): + """ + Mixin class that aligns corpora with pronunciation dictionaries + + See Also + -------- + 
:class:`~montreal_forced_aligner.corpus.acoustic_corpus.AcousticCorpusPronunciationMixin` + For dictionary and corpus parsing parameters + :class:`~montreal_forced_aligner.alignment.mixins.AlignMixin` + For alignment parameters + :class:`~montreal_forced_aligner.abc.FileExporterMixin` + For file exporting parameters + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def cleanup_word_ctm_arguments(self) -> list[CleanupWordCtmArguments]: + """ + Generate Job arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.CleanupWordCtmProcessWorker` + + Returns + ------- + list[:class:`~montreal_forced_aligner.alignment.multiprocessing.CleanupWordCtmArguments`] + Arguments for processing + """ + return [ + CleanupWordCtmArguments( + j.construct_path_dictionary(self.working_directory, "word", "ctm"), + j.current_dictionary_names, + j.job_utts(), + j.dictionary_data(), + ) + for j in self.jobs + ] + + def no_cleanup_word_ctm_arguments(self) -> list[NoCleanupWordCtmArguments]: + """ + Generate Job arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.NoCleanupWordCtmProcessWorker` + + Returns + ------- + list[:class:`~montreal_forced_aligner.alignment.multiprocessing.NoCleanupWordCtmArguments`] + Arguments for processing + """ + return [ + NoCleanupWordCtmArguments( + j.construct_path_dictionary(self.working_directory, "word", "ctm"), + j.current_dictionary_names, + j.job_utts(), + j.dictionary_data(), + ) + for j in self.jobs + ] + + def phone_ctm_arguments(self) -> list[PhoneCtmArguments]: + """ + Generate Job arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.PhoneCtmProcessWorker` + + Returns + ------- + list[:class:`~montreal_forced_aligner.alignment.multiprocessing.PhoneCtmArguments`] + Arguments for processing + """ + return [ + PhoneCtmArguments( + j.construct_path_dictionary(self.working_directory, "phone", "ctm"), + j.current_dictionary_names, + j.job_utts(), + j.reversed_phone_mappings(), 
+ j.positions(), + ) + for j in self.jobs + ] + + def combine_ctm_arguments(self) -> list[CombineCtmArguments]: + """ + Generate Job arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.CombineProcessWorker` + + Returns + ------- + list[:class:`~montreal_forced_aligner.alignment.multiprocessing.CombineCtmArguments`] + Arguments for processing + """ + return [ + CombineCtmArguments( + j.current_dictionary_names, + j.job_files(), + j.job_speakers(), + j.dictionary_data(), + self.cleanup_textgrids, + ) + for j in self.jobs + ] + + def export_textgrid_arguments(self) -> list[ExportTextGridArguments]: + """ + Generate Job arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.ExportTextGridProcessWorker` + + Returns + ------- + list[:class:`~montreal_forced_aligner.alignment.multiprocessing.ExportTextGridArguments`] + Arguments for processing + """ + return [ + ExportTextGridArguments( + self.files, + self.frame_shift, + self.textgrid_output, + self.backup_output_directory, + ) + for _ in self.jobs + ] + + @property + def backup_output_directory(self) -> Optional[str]: + """Backup directory if overwriting is not allowed""" + return None + + def ctms_to_textgrids_mp(self): + """ + Multiprocessing function for exporting alignment CTM information as TextGrids + + See Also + -------- + :func:`~montreal_forced_aligner.alignment.multiprocessing.ali_to_ctm_func` + Multiprocessing helper function for converting ali archives to CTM format + :class:`~montreal_forced_aligner.alignment.multiprocessing.PhoneCtmProcessWorker` + Multiprocessing helper class for processing CTM files + :meth:`.CorpusAligner.phone_ctm_arguments` + Job method for generating arguments for PhoneCtmProcessWorker + :class:`~montreal_forced_aligner.alignment.multiprocessing.CleanupWordCtmProcessWorker` + Multiprocessing helper class for processing CTM files + :meth:`.CorpusAligner.cleanup_word_ctm_arguments` + Job method for generating arguments for 
CleanupWordCtmProcessWorker + :class:`~montreal_forced_aligner.alignment.multiprocessing.NoCleanupWordCtmProcessWorker` + Multiprocessing helper class for processing CTM files + :meth:`.CorpusAligner.no_cleanup_word_ctm_arguments` + Job method for generating arguments for NoCleanupWordCtmProcessWorker + :class:`~montreal_forced_aligner.alignment.multiprocessing.CombineProcessWorker` + Multiprocessing helper class for combining word and phone alignments + :meth:`.CorpusAligner.combine_ctm_arguments` + Job method for generating arguments for NoCleanupWordCtmProcessWorker + :class:`~montreal_forced_aligner.alignment.multiprocessing.ExportPreparationProcessWorker` + Multiprocessing helper class for generating TextGrid tiers + :class:`~montreal_forced_aligner.alignment.multiprocessing.ExportTextGridProcessWorker` + Multiprocessing helper class for exporting TextGrid files + :meth:`.CorpusAligner.export_textgrid_arguments` + Job method for generating arguments for ExportTextGridProcessWorker + :kaldi_steps:`get_train_ctm` + Reference Kaldi script + + """ + export_begin = time.time() + manager = mp.Manager() + textgrid_errors = manager.dict() + error_catching = manager.dict() + stopped = Stopped() + if not self.overwrite: + os.makedirs(self.backup_output_directory, exist_ok=True) + + self.logger.debug("Beginning to process ctm files...") + ctm_begin_time = time.time() + word_procs = [] + phone_procs = [] + combine_procs = [] + finished_signals = [Stopped() for _ in range(self.num_jobs)] + finished_processing = Stopped() + to_process_queue = [mp.JoinableQueue() for _ in range(self.num_jobs)] + to_export_queue = mp.JoinableQueue() + for_write_queue = mp.JoinableQueue() + finished_combining = Stopped() + + if self.cleanup_textgrids: + word_ctm_args = self.cleanup_word_ctm_arguments() + else: + word_ctm_args = self.no_cleanup_word_ctm_arguments() + phone_ctm_args = self.phone_ctm_arguments() + combine_ctm_args = self.combine_ctm_arguments() + export_args = 
self.export_textgrid_arguments() + for j in self.jobs: + if self.cleanup_textgrids: + word_p = CleanupWordCtmProcessWorker( + j.name, + to_process_queue[j.name], + stopped, + error_catching, + word_ctm_args[j.name], + ) + else: + word_p = NoCleanupWordCtmProcessWorker( + j.name, + to_process_queue[j.name], + stopped, + error_catching, + word_ctm_args[j.name], + ) + + word_procs.append(word_p) + word_p.start() + + phone_p = PhoneCtmProcessWorker( + j.name, + to_process_queue[j.name], + stopped, + error_catching, + phone_ctm_args[j.name], + ) + phone_p.start() + phone_procs.append(phone_p) + + combine_p = CombineProcessWorker( + j.name, + to_process_queue[j.name], + to_export_queue, + stopped, + finished_signals[j.name], + error_catching, + combine_ctm_args[j.name], + ) + combine_p.start() + combine_procs.append(combine_p) + preparation_proc = ExportPreparationProcessWorker( + to_export_queue, for_write_queue, stopped, finished_combining, self.files + ) + preparation_proc.start() + + export_procs = [] + for j in self.jobs: + export_proc = ExportTextGridProcessWorker( + for_write_queue, + stopped, + finished_processing, + textgrid_errors, + export_args[j.name], + ) + export_proc.start() + export_procs.append(export_proc) + + self.logger.debug("Waiting for processes to finish...") + for i in range(self.num_jobs): + word_procs[i].join() + phone_procs[i].join() + finished_signals[i].stop() + + self.logger.debug(f"Ctm parsers took {time.time() - ctm_begin_time} seconds") + + self.logger.debug("Waiting for processes to finish...") + for i in range(self.num_jobs): + to_process_queue[i].join() + combine_procs[i].join() + finished_combining.stop() + + self.logger.debug(f"Combiners took {time.time() - ctm_begin_time} seconds") + self.logger.debug("Beginning export...") + + to_export_queue.join() + preparation_proc.join() + + self.logger.debug(f"Adding jobs for export took {time.time() - export_begin}") + self.logger.debug("Waiting for export processes to join...") + + 
finished_processing.stop() + for i in range(self.num_jobs): + export_procs[i].join() + for_write_queue.join() + self.logger.debug(f"Export took {time.time() - export_begin} seconds") + + if error_catching: + self.logger.error("Error was encountered in processing CTMs") + for key, error in error_catching.items(): + self.logger.error(f"{key}:\n\n{error}") + raise AlignmentExportError(error_catching) + + if textgrid_errors: + self.logger.warning( + f"There were {len(textgrid_errors)} errors encountered in generating TextGrids. " + f"Check the output_errors.txt file in {os.path.join(self.textgrid_output)} " + f"for more details" + ) + output_textgrid_writing_errors(self.textgrid_output, textgrid_errors) + + def convert_ali_to_textgrids(self) -> None: + """ + Multiprocessing function that aligns based on the current model. + + See Also + -------- + :func:`~montreal_forced_aligner.alignment.multiprocessing.ali_to_ctm_func` + Multiprocessing helper function for each job + :meth:`.CorpusAligner.ali_to_word_ctm_arguments` + Job method for generating arguments for this function + :meth:`.CorpusAligner.ali_to_phone_ctm_arguments` + Job method for generating arguments for this function + :kaldi_steps:`get_train_ctm` + Reference Kaldi script + """ + log_directory = self.working_log_directory + os.makedirs(self.textgrid_output, exist_ok=True) + jobs = self.ali_to_word_ctm_arguments() # Word CTM jobs + jobs += self.ali_to_phone_ctm_arguments() # Phone CTM jobs + self.logger.info("Generating CTMs from alignment...") + if self.use_mp: + run_mp(ali_to_ctm_func, jobs, log_directory) + else: + run_non_mp(ali_to_ctm_func, jobs, log_directory) + self.logger.info("Finished generating CTMs!") + + self.logger.info("Exporting TextGrids from CTMs...") + if self.use_mp: + self.ctms_to_textgrids_mp() + else: + self.ctms_to_textgrids_non_mp() + self.logger.info("Finished exporting TextGrids!") + + def ctms_to_textgrids_non_mp(self) -> None: + """ + Parse CTM files to TextGrids without using 
multiprocessing + """ + + def process_current_word_labels(): + """Process the current stack of word labels""" + speaker = cur_utt.speaker + + text = cur_utt.text.split() + if self.cleanup_textgrids: + actual_labels = parse_from_word(current_labels, text, speaker.dictionary_data) + else: + actual_labels = parse_from_word_no_cleanup( + current_labels, speaker.dictionary_data.reversed_words_mapping + ) + cur_utt.word_labels = actual_labels + + def process_current_phone_labels(): + """Process the current stack of phone labels""" + speaker = cur_utt.speaker + + cur_utt.phone_labels = parse_from_phone( + current_labels, + speaker.dictionary.reversed_phone_mapping, + speaker.dictionary.positions, + ) + + export_errors = {} + if self.cleanup_textgrids: + w_args = self.cleanup_word_ctm_arguments() + else: + w_args = self.no_cleanup_word_ctm_arguments() + p_args = self.phone_ctm_arguments() + for j in self.jobs: + + word_arguments = w_args[j.name] + phone_arguments = p_args[j.name] + self.logger.debug(f"Parsing ctms for job {j.name}...") + cur_utt = None + current_labels = [] + for dict_name in word_arguments.dictionaries: + with open(word_arguments.ctm_paths[dict_name], "r") as f: + for line in f: + line = line.strip() + if line == "": + continue + ctm_interval = process_ctm_line(line) + utt = self.utterances[ctm_interval.utterance] + if cur_utt is None: + cur_utt = utt + if utt.is_segment: + utt_begin = utt.begin + else: + utt_begin = 0 + if utt != cur_utt: + process_current_word_labels() + cur_utt = utt + current_labels = [] + + ctm_interval.shift_times(utt_begin) + current_labels.append(ctm_interval) + if current_labels: + process_current_word_labels() + cur_utt = None + current_labels = [] + for dict_name in phone_arguments.dictionaries: + with open(phone_arguments.ctm_paths[dict_name], "r") as f: + for line in f: + line = line.strip() + if line == "": + continue + ctm_interval = process_ctm_line(line) + utt = self.utterances[ctm_interval.utterance] + if cur_utt is 
None: + cur_utt = utt + if utt.is_segment: + utt_begin = utt.begin + else: + utt_begin = 0 + if utt != cur_utt and cur_utt is not None: + process_current_phone_labels() + cur_utt = utt + current_labels = [] + + ctm_interval.shift_times(utt_begin) + current_labels.append(ctm_interval) + if current_labels: + process_current_phone_labels() + + self.logger.debug(f"Generating TextGrids for job {j.name}...") + processed_files = set() + for file in j.job_files().values(): + first_file_write = True + if file.name in processed_files: + first_file_write = False + try: + ctm_to_textgrid(file, self, first_file_write) + processed_files.add(file.name) + except Exception: + if self.debug: + raise + exc_type, exc_value, exc_traceback = sys.exc_info() + export_errors[file.name] = "\n".join( + traceback.format_exception(exc_type, exc_value, exc_traceback) + ) + if export_errors: + self.logger.warning( + f"There were {len(export_errors)} errors encountered in generating TextGrids. " + f"Check the output_errors.txt file in {os.path.join(self.textgrid_output)} " + f"for more details" + ) + output_textgrid_writing_errors(self.textgrid_output, export_errors) + + def export_files(self, output_directory: str) -> None: + """ + Export a TextGrid file for every sound file in the dataset + + Parameters + ---------- + output_directory: str + Directory to save to + """ + begin = time.time() + self.textgrid_output = output_directory + if self.backup_output_directory is not None and os.path.exists( + self.backup_output_directory + ): + shutil.rmtree(self.backup_output_directory, ignore_errors=True) + self.convert_ali_to_textgrids() + self.logger.debug(f"Exported TextGrids in a total of {time.time() - begin} seconds") + + def ali_to_word_ctm_arguments(self) -> list[AliToCtmArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.alignment.multiprocessing.ali_to_ctm_func` + + Returns + ------- + list[:class:`~montreal_forced_aligner.alignment.multiprocessing.AliToCtmArguments`] 
+ Arguments for processing + """ + return [ + AliToCtmArguments( + os.path.join(self.working_log_directory, f"get_word_ctm.{j.name}.log"), + j.current_dictionary_names, + j.construct_path_dictionary(self.working_directory, "ali", "ark"), + j.construct_path_dictionary(self.data_directory, "text", "int.scp"), + j.word_boundary_int_files(), + round(self.frame_shift / 1000, 4), + self.alignment_model_path, + j.construct_path_dictionary(self.working_directory, "word", "ctm"), + True, + ) + for j in self.jobs + ] + + def ali_to_phone_ctm_arguments(self) -> list[AliToCtmArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.alignment.multiprocessing.ali_to_ctm_func` + + Returns + ------- + list[:class:`~montreal_forced_aligner.alignment.multiprocessing.AliToCtmArguments`] + Arguments for processing + """ + return [ + AliToCtmArguments( + os.path.join(self.working_log_directory, f"get_phone_ctm.{j.name}.log"), + j.current_dictionary_names, + j.construct_path_dictionary(self.working_directory, "ali", "ark"), + j.construct_path_dictionary(self.data_directory, "text", "int.scp"), + j.word_boundary_int_files(), + round(self.frame_shift / 1000, 4), + self.alignment_model_path, + j.construct_path_dictionary(self.working_directory, "phone", "ctm"), + False, + ) + for j in self.jobs + ] diff --git a/montreal_forced_aligner/alignment/mixins.py b/montreal_forced_aligner/alignment/mixins.py new file mode 100644 index 00000000..758555bd --- /dev/null +++ b/montreal_forced_aligner/alignment/mixins.py @@ -0,0 +1,357 @@ +"""Class definitions for alignment mixins""" +from __future__ import annotations + +import logging +import os +import time +from abc import abstractmethod +from typing import TYPE_CHECKING + +from montreal_forced_aligner.alignment.multiprocessing import ( + AlignArguments, + CompileInformationArguments, + CompileTrainGraphsArguments, + align_func, + compile_information_func, + compile_train_graphs_func, +) +from 
montreal_forced_aligner.dictionary.mixins import DictionaryMixin +from montreal_forced_aligner.exceptions import AlignmentError +from montreal_forced_aligner.utils import run_mp, run_non_mp + +if TYPE_CHECKING: + from montreal_forced_aligner.abc import MetaDict + from montreal_forced_aligner.corpus.multiprocessing import Job + + +class AlignMixin(DictionaryMixin): + """ + Configuration object for alignment + + Parameters + ---------- + transition_scale : float + Transition scale, defaults to 1.0 + acoustic_scale : float + Acoustic scale, defaults to 0.1 + self_loop_scale : float + Self-loop scale, defaults to 0.1 + boost_silence : float + Factor to boost silence probabilities, 1.0 is no boost or reduction + beam : int + Size of the beam to use in decoding, defaults to 10 + retry_beam : int + Size of the beam to use in decoding if it fails with the initial beam width, defaults to 40 + + + See Also + -------- + :class:`~montreal_forced_aligner.dictionary.mixins.DictionaryMixin` + For dictionary parsing parameters + + Attributes + ---------- + logger: logging.Logger + Eventual top-level worker logger + jobs: list[Job] + Jobs to process + use_mp: bool + Flag for using multiprocessing + """ + + logger: logging.Logger + jobs: list[Job] + use_mp: bool + + def __init__( + self, + transition_scale: float = 1.0, + acoustic_scale: float = 0.1, + self_loop_scale: float = 0.1, + boost_silence: float = 1.0, + beam: int = 10, + retry_beam: int = 40, + **kwargs, + ): + super().__init__(**kwargs) + self.transition_scale = transition_scale + self.acoustic_scale = acoustic_scale + self.self_loop_scale = self_loop_scale + self.boost_silence = boost_silence + self.beam = beam + self.retry_beam = retry_beam + if self.retry_beam <= self.beam: + self.retry_beam = self.beam * 4 + + @property + def tree_path(self): + """Path to tree file""" + return os.path.join(self.working_directory, "tree") + + @property + @abstractmethod + def data_directory(self): + """Corpus data directory""" + ... 
+ + @abstractmethod + def construct_feature_proc_strings(self) -> list[dict[str, str]]: + """Generate feature strings""" + ... + + def compile_train_graphs_arguments(self) -> list[CompileTrainGraphsArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.alignment.multiprocessing.compile_train_graphs_func` + + Returns + ------- + list[:class:`~montreal_forced_aligner.alignment.multiprocessing.CompileTrainGraphsArguments`] + Arguments for processing + """ + args = [] + for j in self.jobs: + lexicon_fst_paths = { + dictionary.name: dictionary.lexicon_fst_path + for dictionary in j.current_dictionaries + } + model_path = self.model_path + if not os.path.exists(model_path): + model_path = self.alignment_model_path + args.append( + CompileTrainGraphsArguments( + os.path.join(self.working_log_directory, f"compile_train_graphs.{j.name}.log"), + j.current_dictionary_names, + os.path.join(self.working_directory, "tree"), + model_path, + j.construct_path_dictionary(self.data_directory, "text", "int.scp"), + self.disambiguation_symbols_int_path, + lexicon_fst_paths, + j.construct_path_dictionary(self.working_directory, "fsts", "scp"), + ) + ) + return args + + def align_arguments(self) -> list[AlignArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.alignment.multiprocessing.align_func` + + Returns + ------- + list[:class:`~montreal_forced_aligner.alignment.multiprocessing.AlignArguments`] + Arguments for processing + """ + args = [] + feat_strings = self.construct_feature_proc_strings() + iteration = getattr(self, "iteration", None) + for j in self.jobs: + if iteration is not None: + log_path = os.path.join( + self.working_log_directory, f"align.{iteration}.{j.name}.log" + ) + else: + log_path = os.path.join(self.working_log_directory, f"align.{j.name}.log") + args.append( + AlignArguments( + log_path, + j.current_dictionary_names, + j.construct_path_dictionary(self.working_directory, "fsts", "scp"), + feat_strings[j.name], + 
self.alignment_model_path, + j.construct_path_dictionary(self.working_directory, "ali", "ark"), + self.align_options, + ) + ) + return args + + def compile_information_arguments(self) -> list[CompileInformationArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.alignment.multiprocessing.compile_information_func` + + Returns + ------- + list[:class:`~montreal_forced_aligner.alignment.multiprocessing.CompileInformationArguments`] + Arguments for processing + """ + args = [] + iteration = getattr(self, "iteration", None) + for j in self.jobs: + if iteration is not None: + log_path = os.path.join( + self.working_log_directory, f"align.{iteration}.{j.name}.log" + ) + else: + log_path = os.path.join(self.working_log_directory, f"align.{j.name}.log") + args.append(CompileInformationArguments(log_path)) + return args + + @property + def align_options(self) -> MetaDict: + """Options for use in aligning""" + return { + "transition_scale": self.transition_scale, + "acoustic_scale": self.acoustic_scale, + "self_loop_scale": self.self_loop_scale, + "beam": self.beam, + "retry_beam": self.retry_beam, + "boost_silence": self.boost_silence, + "optional_silence_csl": self.optional_silence_csl, + } + + def alignment_configuration(self) -> MetaDict: + """Configuration parameters""" + return { + "transition_scale": self.transition_scale, + "acoustic_scale": self.acoustic_scale, + "self_loop_scale": self.self_loop_scale, + "boost_silence": self.boost_silence, + "beam": self.beam, + "retry_beam": self.retry_beam, + } + + def compile_train_graphs(self) -> None: + """ + Multiprocessing function that compiles training graphs for utterances. 
+ + See Also + -------- + :func:`~montreal_forced_aligner.alignment.multiprocessing.compile_train_graphs_func` + Multiprocessing helper function for each job + :meth:`.AlignMixin.compile_train_graphs_arguments` + Job method for generating arguments for the helper function + :kaldi_steps:`align_si` + Reference Kaldi script + :kaldi_steps:`align_fmllr` + Reference Kaldi script + """ + self.logger.debug("Compiling training graphs...") + begin = time.time() + log_directory = self.working_log_directory + os.makedirs(log_directory, exist_ok=True) + jobs = self.compile_train_graphs_arguments() + if self.use_mp: + run_mp(compile_train_graphs_func, jobs, log_directory) + else: + run_non_mp(compile_train_graphs_func, jobs, log_directory) + self.logger.debug(f"Compiling training graphs took {time.time() - begin}") + + def align_utterances(self) -> None: + """ + Multiprocessing function that aligns based on the current model. + + See Also + -------- + :func:`~montreal_forced_aligner.alignment.multiprocessing.align_func` + Multiprocessing helper function for each job + :meth:`.AlignMixin.align_arguments` + Job method for generating arguments for the helper function + :kaldi_steps:`align_si` + Reference Kaldi script + :kaldi_steps:`align_fmllr` + Reference Kaldi script + """ + begin = time.time() + log_directory = self.working_log_directory + + arguments = self.align_arguments() + if self.use_mp: + run_mp(align_func, arguments, log_directory) + else: + run_non_mp(align_func, arguments, log_directory) + + self.compile_information() + error_logs = [] + for j in arguments: + + with open(j.log_path, "r", encoding="utf8") as f: + for line in f: + if line.strip().startswith("ERROR"): + error_logs.append(j.log_path) + break + if error_logs: + raise AlignmentError(error_logs) + self.logger.debug(f"Alignment round took {time.time() - begin}") + + def compile_information(self): + """ + Compiles information about alignment, namely what the overall log-likelihood was + and how many files 
were unaligned. + + See Also + -------- + :func:`~montreal_forced_aligner.alignment.multiprocessing.compile_information_func` + Multiprocessing helper function for each job + :meth:`.AlignMixin.compile_information_arguments` + Job method for generating arguments for the helper function + """ + compile_info_begin = time.time() + + jobs = self.compile_information_arguments() + + if self.use_mp: + alignment_info = run_mp( + compile_information_func, jobs, self.working_log_directory, True + ) + else: + alignment_info = run_non_mp( + compile_information_func, jobs, self.working_log_directory, True + ) + + avg_like_sum = 0 + avg_like_frames = 0 + average_logdet_sum = 0 + average_logdet_frames = 0 + beam_too_narrow_count = 0 + too_short_count = 0 + for data in alignment_info.values(): + beam_too_narrow_count += len(data["unaligned"]) + too_short_count += len(data["too_short"]) + avg_like_frames += data["total_frames"] + avg_like_sum += data["log_like"] * data["total_frames"] + if "logdet_frames" in data: + average_logdet_frames += data["logdet_frames"] + average_logdet_sum += data["logdet"] * data["logdet_frames"] + + if not avg_like_frames: + self.logger.warning( + "No files were aligned, this likely indicates serious problems with the aligner." + ) + else: + if too_short_count: + self.logger.debug( + f"There were {too_short_count} utterances that were too short to be aligned." + ) + if beam_too_narrow_count: + self.logger.debug( + f"There were {beam_too_narrow_count} utterances that could not be aligned with " + f"the current beam settings." + ) + average_log_like = avg_like_sum / avg_like_frames + if average_logdet_sum: + average_log_like += average_logdet_sum / average_logdet_frames + self.logger.debug(f"Average per frame likelihood for alignment: {average_log_like}") + self.logger.debug(f"Compiling information took {time.time() - compile_info_begin}") + + @property + @abstractmethod + def working_directory(self) -> str: + """Working directory""" + ... 
+ + @property + @abstractmethod + def working_log_directory(self) -> str: + """Working log directory""" + ... + + @property + def model_path(self) -> str: + """Acoustic model file path""" + return os.path.join(self.working_directory, "final.mdl") + + @property + def alignment_model_path(self) -> str: + """Acoustic model file path for speaker-independent alignment""" + path = os.path.join(self.working_directory, "final.alimdl") + if os.path.exists(path): + return path + return self.model_path diff --git a/montreal_forced_aligner/alignment/multiprocessing.py b/montreal_forced_aligner/alignment/multiprocessing.py new file mode 100644 index 00000000..78932935 --- /dev/null +++ b/montreal_forced_aligner/alignment/multiprocessing.py @@ -0,0 +1,1025 @@ +""" +Alignment multiprocessing functions +----------------------------------- + +""" +from __future__ import annotations + +import multiprocessing as mp +import os +import re +import subprocess +import sys +import traceback +from queue import Empty +from typing import TYPE_CHECKING, NamedTuple, Union + +from montreal_forced_aligner.textgrid import ( + CtmInterval, + export_textgrid, + generate_tiers, + parse_from_phone, + parse_from_word, + parse_from_word_no_cleanup, + process_ctm_line, +) +from montreal_forced_aligner.utils import Stopped, thirdparty_binary + +if TYPE_CHECKING: + from montreal_forced_aligner.abc import CtmErrorDict, MetaDict, ReversedMappingType + from montreal_forced_aligner.corpus.classes import File, Speaker, Utterance + from montreal_forced_aligner.dictionary import DictionaryData + + +queue_polling_timeout = 1 + +__all__ = [ + "PhoneCtmProcessWorker", + "CleanupWordCtmProcessWorker", + "NoCleanupWordCtmProcessWorker", + "CombineProcessWorker", + "ExportPreparationProcessWorker", + "ExportTextGridProcessWorker", + "align_func", + "ali_to_ctm_func", + "compile_information_func", + "compile_train_graphs_func", +] + + +class AliToCtmArguments(NamedTuple): + """Arguments for 
:func:`~montreal_forced_aligner.alignment.multiprocessing.ali_to_ctm_func`""" + + log_path: str + dictionaries: list[str] + ali_paths: dict[str, str] + text_int_paths: dict[str, str] + word_boundary_int_paths: dict[str, str] + frame_shift: float + model_path: str + ctm_paths: dict[str, str] + word_mode: bool + + +class CleanupWordCtmArguments(NamedTuple): + """Arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.CleanupWordCtmProcessWorker`""" + + ctm_paths: dict[str, str] + dictionaries: list[str] + utterances: dict[str, dict[str, Utterance]] + dictionary_data: dict[str, DictionaryData] + + +class NoCleanupWordCtmArguments(NamedTuple): + """Arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.NoCleanupWordCtmProcessWorker`""" + + ctm_paths: dict[str, str] + dictionaries: list[str] + utterances: dict[str, dict[str, Utterance]] + dictionary_data: dict[str, DictionaryData] + + +class PhoneCtmArguments(NamedTuple): + """Arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.PhoneCtmProcessWorker`""" + + ctm_paths: dict[str, str] + dictionaries: list[str] + utterances: dict[str, dict[str, Utterance]] + reversed_phone_mappings: dict[str, ReversedMappingType] + positions: dict[str, list[str]] + + +class CombineCtmArguments(NamedTuple): + """Arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.CombineProcessWorker`""" + + dictionaries: list[str] + files: dict[str, File] + speakers: dict[str, Speaker] + dictionary_data: dict[str, DictionaryData] + cleanup_textgrids: bool + + +class ExportTextGridArguments(NamedTuple): + """Arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.ExportTextGridProcessWorker`""" + + files: dict[str, File] + frame_shift: int + output_directory: str + backup_output_directory: str + + +class CompileInformationArguments(NamedTuple): + """Arguments for :func:`~montreal_forced_aligner.alignment.multiprocessing.compile_information_func`""" + + 
align_log_paths: str + + +class CompileTrainGraphsArguments(NamedTuple): + """Arguments for :func:`~montreal_forced_aligner.alignment.multiprocessing.compile_train_graphs_func`""" + + log_path: str + dictionaries: list[str] + tree_path: str + model_path: str + text_int_paths: dict[str, str] + disambig_paths: dict[str, str] + lexicon_fst_paths: dict[str, str] + fst_scp_paths: dict[str, str] + + +class AlignArguments(NamedTuple): + """Arguments for :func:`~montreal_forced_aligner.alignment.multiprocessing.align_func`""" + + log_path: str + dictionaries: list[str] + fst_scp_paths: dict[str, str] + feature_strings: dict[str, str] + model_path: str + ali_paths: dict[str, str] + align_options: MetaDict + + +def compile_train_graphs_func( + log_path: str, + dictionaries: list[str], + tree_path: str, + model_path: str, + text_int_paths: dict[str, str], + disambig_path: str, + lexicon_fst_paths: dict[str, str], + fst_scp_paths: dict[str, str], +) -> None: + """ + Multiprocessing function to compile training graphs + + See Also + -------- + :meth:`.AlignMixin.compile_train_graphs` + Main function that calls this function in parallel + :meth:`.AlignMixin.compile_train_graphs_arguments` + Job method for generating arguments for this function + :kaldi_src:`compile-train-graphs` + Relevant Kaldi binary + + Parameters + ---------- + log_path: str + Path to save log output + dictionaries: list[str] + List of dictionary names + tree_path: str + Path to the acoustic model tree file + model_path: str + Path to the acoustic model file + text_int_paths: dict[str, str] + Dictionary of text int files per dictionary name + disambig_path: str + Disambiguation symbol int file + lexicon_fst_paths: dict[str, str] + Dictionary of L.fst files per dictionary name + fst_scp_paths: dict[str, str] + Dictionary of utterance FST scp files per dictionary name + """ + with open(log_path, "w", encoding="utf8") as log_file: + for dict_name in dictionaries: + fst_scp_path = fst_scp_paths[dict_name] + 
fst_ark_path = fst_scp_path.replace(".scp", ".ark") + text_path = text_int_paths[dict_name] + log_file.write(f"{dict_name}\t{fst_scp_path}\t{fst_ark_path}\t{text_path}\n\n") + log_file.flush() + proc = subprocess.Popen( + [ + thirdparty_binary("compile-train-graphs"), + f"--read-disambig-syms={disambig_path}", + tree_path, + model_path, + lexicon_fst_paths[dict_name], + f"ark:{text_path}", + f"ark,scp:{fst_ark_path},{fst_scp_path}", + ], + stderr=log_file, + env=os.environ, + ) + proc.communicate() + + +def align_func( + log_path: str, + dictionaries: list[str], + fst_scp_paths: dict[str, str], + feature_strings: dict[str, str], + model_path: str, + ali_paths: dict[str, str], + align_options: MetaDict, +): + """ + Multiprocessing function for alignment. + + See Also + -------- + :meth:`.AlignMixin.align_utterances` + Main function that calls this function in parallel + :meth:`.AlignMixin.align_arguments` + Job method for generating arguments for this function + :kaldi_src:`align-equal-compiled` + Relevant Kaldi binary + :kaldi_src:`gmm-boost-silence` + Relevant Kaldi binary + + Parameters + ---------- + log_path: str + Path to save log output + dictionaries: list[str] + List of dictionary names + fst_scp_paths: dict[str, str] + Dictionary of FST scp file paths per dictionary name + feature_strings: dict[str, str] + Dictionary of feature strings per dictionary name + model_path: str + Path to the acoustic model file + ali_paths: dict[str, str] + Dictionary of alignment archives per dictionary name + align_options: dict[str, Any] + Options for alignment + """ + with open(log_path, "w", encoding="utf8") as log_file: + for dict_name in dictionaries: + feature_string = feature_strings[dict_name] + fst_path = fst_scp_paths[dict_name] + ali_path = ali_paths[dict_name] + com = [ + thirdparty_binary("gmm-align-compiled"), + f"--transition-scale={align_options['transition_scale']}", + f"--acoustic-scale={align_options['acoustic_scale']}", + 
f"--self-loop-scale={align_options['self_loop_scale']}", + f"--beam={align_options['beam']}", + f"--retry-beam={align_options['retry_beam']}", + "--careful=false", + "-", + f"scp:{fst_path}", + feature_string, + f"ark:{ali_path}", + ] + + boost_proc = subprocess.Popen( + [ + thirdparty_binary("gmm-boost-silence"), + f"--boost={align_options['boost_silence']}", + align_options["optional_silence_csl"], + model_path, + "-", + ], + stderr=log_file, + stdout=subprocess.PIPE, + env=os.environ, + ) + align_proc = subprocess.Popen( + com, stderr=log_file, stdin=boost_proc.stdout, env=os.environ + ) + align_proc.communicate() + + +def compile_information_func(align_log_path: str) -> dict[str, Union[list[str], float, int]]: + """ + Multiprocessing function for compiling information about alignment + + See Also + -------- + :meth:`.AlignMixin.compile_information` + Main function that calls this function in parallel + + Parameters + ---------- + align_log_path: str + Log path for alignment + + Returns + ------- + dict[str, Union[list[str], float, int]] + Information about log-likelihood and number of unaligned files + """ + average_logdet_pattern = re.compile( + r"Overall average logdet is (?P[-.,\d]+) over (?P[.\d+e]+) frames" + ) + log_like_pattern = re.compile( + r"^LOG .* Overall log-likelihood per frame is (?P[-0-9.]+) over (?P\d+) frames.*$" + ) + + decode_error_pattern = re.compile( + r"^WARNING .* Did not successfully decode file (?P.*?), .*$" + ) + + data = {"unaligned": [], "too_short": [], "log_like": 0, "total_frames": 0} + with open(align_log_path, "r", encoding="utf8") as f: + for line in f: + decode_error_match = re.match(decode_error_pattern, line) + if decode_error_match: + data["unaligned"].append(decode_error_match.group("utt")) + continue + log_like_match = re.match(log_like_pattern, line) + if log_like_match: + log_like = log_like_match.group("log_like") + frames = log_like_match.group("frames") + data["log_like"] = float(log_like) + data["total_frames"] = 
int(frames) + m = re.search(average_logdet_pattern, line) + if m: + logdet = float(m.group("logdet")) + frames = float(m.group("frames")) + data["logdet"] = logdet + data["logdet_frames"] = frames + return data + + +def ali_to_ctm_func( + log_path: str, + dictionaries: list[str], + ali_paths: dict[str, str], + text_int_paths: dict[str, str], + word_boundary_int_paths: dict[str, str], + frame_shift: float, + model_path: str, + ctm_paths: dict[str, str], + word_mode: bool, +) -> None: + """ + Multiprocessing function to convert alignment archives into CTM files + + See Also + -------- + :meth:`.CorpusAligner.ctms_to_textgrids_mp` + Main function that calls this function in parallel + :meth:`.CorpusAligner.ali_to_word_ctm_arguments` + Job method for generating arguments for this function + :meth:`.CorpusAligner.ali_to_phone_ctm_arguments` + Job method for generating arguments for this function + :kaldi_src:`linear-to-nbest` + Relevant Kaldi binary + :kaldi_src:`lattice-determinize-pruned` + Relevant Kaldi binary + :kaldi_src:`lattice-align-words` + Relevant Kaldi binary + :kaldi_src:`lattice-to-phone-lattice` + Relevant Kaldi binary + :kaldi_src:`nbest-to-ctm` + Relevant Kaldi binary + + Parameters + ---------- + log_path: str + Path to save log output + dictionaries: list[str] + List of dictionary names + ali_paths: dict[str, str] + Dictionary of alignment archives per dictionary name + text_int_paths: dict[str, str] + Dictionary of text int files per dictionary name + word_boundary_int_paths: dict[str, str] + Dictionary of word boundary int files per dictionary name + frame_shift: float + Frame shift of feature generation in seconds + model_path: str + Path to the acoustic model file + ctm_paths: dict[str, str] + Dictionary of CTM files per dictionary name + word_mode: bool + Flag for whether to parse words or phones + """ + with open(log_path, "w", encoding="utf8") as log_file: + for dict_name in dictionaries: + ali_path = ali_paths[dict_name] + text_int_path = 
text_int_paths[dict_name] + ctm_path = ctm_paths[dict_name] + word_boundary_int_path = word_boundary_int_paths[dict_name] + if os.path.exists(ctm_path): + return + lin_proc = subprocess.Popen( + [ + thirdparty_binary("linear-to-nbest"), + "ark:" + ali_path, + "ark:" + text_int_path, + "", + "", + "ark:-", + ], + stdout=subprocess.PIPE, + stderr=log_file, + env=os.environ, + ) + align_words_proc = subprocess.Popen( + [ + thirdparty_binary("lattice-align-words"), + word_boundary_int_path, + model_path, + "ark:-", + "ark:-", + ], + stdin=lin_proc.stdout, + stdout=subprocess.PIPE, + stderr=log_file, + env=os.environ, + ) + if word_mode: + nbest_proc = subprocess.Popen( + [ + thirdparty_binary("nbest-to-ctm"), + f"--frame-shift={frame_shift}", + "ark:-", + ctm_path, + ], + stderr=log_file, + stdin=align_words_proc.stdout, + env=os.environ, + ) + else: + phone_proc = subprocess.Popen( + [thirdparty_binary("lattice-to-phone-lattice"), model_path, "ark:-", "ark:-"], + stdout=subprocess.PIPE, + stdin=align_words_proc.stdout, + stderr=log_file, + env=os.environ, + ) + nbest_proc = subprocess.Popen( + [ + thirdparty_binary("nbest-to-ctm"), + f"--frame-shift={frame_shift}", + "ark:-", + ctm_path, + ], + stdin=phone_proc.stdout, + stderr=log_file, + env=os.environ, + ) + nbest_proc.communicate() + + +class NoCleanupWordCtmProcessWorker(mp.Process): + """ + Multiprocessing worker for loading word CTM files without any clean up + + See Also + -------- + :meth:`.CorpusAligner.ctms_to_textgrids_mp` + Main function that runs this worker in parallel + + Parameters + ---------- + job_name: int + Job name + to_process_queue: :class:`~multiprocessing.Queue` + Return queue of jobs for later workers to process + stopped: :class:`~montreal_forced_aligner.utils.Stopped` + Stop check for processing + error_catching: dict[tuple[str, int], str] + Dictionary for storing errors encountered + arguments: :class:`~montreal_forced_aligner.alignment.multiprocessing.NoCleanupWordCtmArguments` + 
Arguments to pass to the CTM processing function + """ + + def __init__( + self, + job_name: int, + to_process_queue: mp.Queue, + stopped: Stopped, + error_catching: CtmErrorDict, + arguments: NoCleanupWordCtmArguments, + ): + mp.Process.__init__(self) + self.job_name = job_name + self.dictionaries = arguments.dictionaries + self.ctm_paths = arguments.ctm_paths + self.to_process_queue = to_process_queue + self.stopped = stopped + self.error_catching = error_catching + + # Corpus information + self.utterances = arguments.utterances + + # Dictionary information + self.dictionary_data = arguments.dictionary_data + + def run(self) -> None: + """ + Run the word processing with no clean up + """ + current_file_data = {} + + def process_current(cur_utt: Utterance, current_labels: list[CtmInterval]): + """Process current stack of intervals""" + actual_labels = parse_from_word_no_cleanup( + current_labels, self.dictionary_data[dict_name].reversed_words_mapping + ) + current_file_data[cur_utt.name] = actual_labels + + def process_current_file(cur_file: str): + """Process current file and add to return queue""" + self.to_process_queue.put(("word", cur_file, current_file_data)) + + cur_utt = None + cur_file = None + utt_begin = 0 + current_labels = [] + try: + for dict_name in self.dictionaries: + with open(self.ctm_paths[dict_name], "r") as word_file: + for line in word_file: + line = line.strip() + if not line: + continue + interval = process_ctm_line(line) + utt = interval.utterance + if cur_utt is None: + cur_utt = self.utterances[dict_name][utt] + utt_begin = cur_utt.begin + cur_file = cur_utt.file_name + + if utt != cur_utt: + process_current(cur_utt, current_labels) + cur_utt = self.utterances[dict_name][utt] + file_name = cur_utt.file_name + if file_name != cur_file: + process_current_file(cur_file) + current_file_data = {} + cur_file = file_name + current_labels = [] + if utt_begin: + interval.shift_times(utt_begin) + current_labels.append(interval) + if 
current_labels: + process_current(cur_utt, current_labels) + process_current_file(cur_file) + except Exception: + self.stopped.stop() + exc_type, exc_value, exc_traceback = sys.exc_info() + self.error_catching[("word", self.job_name)] = "\n".join( + traceback.format_exception(exc_type, exc_value, exc_traceback) + ) + + +class CleanupWordCtmProcessWorker(mp.Process): + """ + Multiprocessing worker for loading word CTM files with cleaning up MFA-internal modifications + + See Also + -------- + :meth:`.CorpusAligner.ctms_to_textgrids_mp` + Main function that runs this worker in parallel + + Parameters + ---------- + job_name: int + Job name + to_process_queue: :class:`~multiprocessing.Queue` + Return queue of jobs for later workers to process + stopped: :class:`~montreal_forced_aligner.utils.Stopped` + Stop check for processing + error_catching: dict[tuple[str, int], str] + Dictionary for storing errors encountered + arguments: :class:`~montreal_forced_aligner.alignment.multiprocessing.CleanupWordCtmArguments` + Arguments to pass to the CTM processing function + """ + + def __init__( + self, + job_name: int, + to_process_queue: mp.Queue, + stopped: Stopped, + error_catching: CtmErrorDict, + arguments: CleanupWordCtmArguments, + ): + mp.Process.__init__(self) + self.job_name = job_name + self.dictionaries = arguments.dictionaries + self.ctm_paths = arguments.ctm_paths + self.to_process_queue = to_process_queue + self.stopped = stopped + self.error_catching = error_catching + + # Corpus information + self.utterances = arguments.utterances + + # Dictionary information + self.dictionary_data = arguments.dictionary_data + + def run(self) -> None: + """ + Run the word processing with clean up + """ + current_file_data = {} + + def process_current(cur_utt: Utterance, current_labels: list[CtmInterval]) -> None: + """Process current stack of intervals""" + text = cur_utt.text.split() + actual_labels = parse_from_word(current_labels, text, self.dictionary_data[dict_name]) + + 
current_file_data[cur_utt.name] = actual_labels + + def process_current_file(cur_file: str) -> None: + """Process current file and add to return queue""" + self.to_process_queue.put(("word", cur_file, current_file_data)) + + cur_utt = None + cur_file = None + utt_begin = 0 + current_labels = [] + try: + for dict_name in self.dictionaries: + ctm_path = self.ctm_paths[dict_name] + with open(ctm_path, "r") as word_file: + for line in word_file: + line = line.strip() + if not line: + continue + interval = process_ctm_line(line) + utt = interval.utterance + if cur_utt is None: + cur_utt = self.utterances[dict_name][utt] + utt_begin = cur_utt.begin + cur_file = cur_utt.file_name + + if utt != cur_utt: + process_current(cur_utt, current_labels) + cur_utt = self.utterances[dict_name][utt] + utt_begin = cur_utt.begin + file_name = cur_utt.file_name + if file_name != cur_file: + process_current_file(cur_file) + current_file_data = {} + cur_file = file_name + current_labels = [] + if utt_begin: + interval.shift_times(utt_begin) + current_labels.append(interval) + if current_labels: + process_current(cur_utt, current_labels) + process_current_file(cur_file) + except Exception: + self.stopped.stop() + exc_type, exc_value, exc_traceback = sys.exc_info() + self.error_catching[("word", self.job_name)] = "\n".join( + traceback.format_exception(exc_type, exc_value, exc_traceback) + ) + + +class PhoneCtmProcessWorker(mp.Process): + """ + Multiprocessing worker for loading phone CTM files + + See Also + -------- + :meth:`.CorpusAligner.ctms_to_textgrids_mp` + Main function that runs this worker in parallel + + Parameters + ---------- + job_name: int + Job name + to_process_queue: :class:`~multiprocessing.Queue` + Return queue of jobs for later workers to process + stopped: :class:`~montreal_forced_aligner.utils.Stopped` + Stop check for processing + error_catching: dict[tuple[str, int], str] + Dictionary for storing errors encountered + arguments: 
:class:`~montreal_forced_aligner.alignment.multiprocessing.PhoneCtmArguments` + Arguments to pass to the CTM processing function + """ + + def __init__( + self, + job_name: int, + to_process_queue: mp.Queue, + stopped: Stopped, + error_catching: CtmErrorDict, + arguments: PhoneCtmArguments, + ): + mp.Process.__init__(self) + self.job_name = job_name + self.dictionaries = arguments.dictionaries + self.ctm_paths = arguments.ctm_paths + self.to_process_queue = to_process_queue + self.stopped = stopped + self.error_catching = error_catching + + self.utterances = arguments.utterances + + self.reversed_phone_mappings = arguments.reversed_phone_mappings + self.positions = arguments.positions + + def run(self) -> None: + """Run the phone processing""" + cur_utt = None + cur_file = None + utt_begin = 0 + current_labels = [] + + current_file_data = {} + + def process_current_utt(cur_utt: Utterance, current_labels: list[CtmInterval]) -> None: + """Process current stack of intervals""" + actual_labels = parse_from_phone( + current_labels, self.reversed_phone_mappings[dict_name], self.positions[dict_name] + ) + current_file_data[cur_utt.name] = actual_labels + + def process_current_file(cur_file: str) -> None: + """Process current file and add to return queue""" + self.to_process_queue.put(("phone", cur_file, current_file_data)) + + try: + for dict_name in self.dictionaries: + with open(self.ctm_paths[dict_name], "r") as word_file: + for line in word_file: + line = line.strip() + if not line: + continue + interval = process_ctm_line(line) + utt = interval.utterance + if cur_utt is None: + cur_utt = self.utterances[dict_name][utt] + cur_file = cur_utt.file_name + utt_begin = cur_utt.begin + + if utt != cur_utt: + + process_current_utt(cur_utt, current_labels) + + cur_utt = self.utterances[dict_name][utt] + file_name = cur_utt.file_name + utt_begin = cur_utt.begin + + if file_name != cur_file: + process_current_file(cur_file) + current_file_data = {} + cur_file = file_name + 
current_labels = [] + if utt_begin: + interval.shift_times(utt_begin) + current_labels.append(interval) + if current_labels: + process_current_utt(cur_utt, current_labels) + process_current_file(cur_file) + except Exception: + self.stopped.stop() + exc_type, exc_value, exc_traceback = sys.exc_info() + self.error_catching[("phone", self.job_name)] = ( + "\n".join(traceback.format_exception(exc_type, exc_value, exc_traceback)) + + f"\n\n{len(self.utterances['english'].keys())}\nCould not find: {utt}\n" + + "\n".join(self.utterances["english"].keys()) + ) + + +class CombineProcessWorker(mp.Process): + """ + Multiprocessing worker for loading phone CTM files + + See Also + -------- + :meth:`.CorpusAligner.ctms_to_textgrids_mp` + Main function that runs this worker in parallel + + Parameters + ---------- + job_name: int + Job name + to_process_queue: :class:`~multiprocessing.Queue` + Input queue of phone and word ctms to combine + to_export_queue: :class:`~multiprocessing.Queue` + Export queue of combined CTMs + stopped: :class:`~montreal_forced_aligner.utils.Stopped` + Stop check for processing + finished_combining: :class:`~montreal_forced_aligner.utils.Stopped` + Signal that this worker has finished combining all CTMs + error_catching: dict[tuple[str, int], str] + Dictionary for storing errors encountered + arguments: :class:`~montreal_forced_aligner.alignment.multiprocessing.CombineCtmArguments` + Arguments to pass to the CTM combining function + """ + + def __init__( + self, + job_name: int, + to_process_queue: mp.Queue, + to_export_queue: mp.Queue, + stopped: Stopped, + finished_combining: Stopped, + error_catching: CtmErrorDict, + arguments: CombineCtmArguments, + ): + mp.Process.__init__(self) + self.job_name = job_name + self.to_process_queue = to_process_queue + self.to_export_queue = to_export_queue + self.stopped = stopped + self.finished_combining = finished_combining + self.error_catching = error_catching + + self.files = arguments.files + self.speakers = 
arguments.speakers + self.dictionary_data = arguments.dictionary_data + self.cleanup_textgrids = arguments.cleanup_textgrids + + for file in self.files.values(): + for s in file.speaker_ordering: + if s.name not in self.speakers: + continue + s.dictionary_data = self.dictionary_data[self.speakers[s.name].dictionary_name] + + def run(self) -> None: + """Run the combination function""" + + phone_data = {} + word_data = {} + while True: + try: + w_p, file_name, data = self.to_process_queue.get(timeout=queue_polling_timeout) + except Empty: + if self.finished_combining.stop_check(): + break + continue + self.to_process_queue.task_done() + if self.stopped.stop_check(): + continue + if w_p == "phone": + if file_name in word_data: + word_ctm = word_data.pop(file_name) + phone_ctm = data + else: + phone_data[file_name] = data + continue + else: + if file_name in phone_data: + phone_ctm = phone_data.pop(file_name) + word_ctm = data + else: + word_data[file_name] = data + continue + try: + file = self.files[file_name] + for u_name, u in file.utterances.items(): + if u_name not in word_ctm: + continue + u.speaker.dictionary_data = self.dictionary_data[ + self.speakers[u.speaker_name].dictionary_name + ] + u.word_labels = word_ctm[u_name] + u.phone_labels = phone_ctm[u_name] + processed_check = True + for s in file.speaker_ordering: + if s.name not in self.speakers: + continue + if not file.has_fully_aligned_speaker(s): + processed_check = False + break + if not processed_check: + continue + data = generate_tiers(file, cleanup_textgrids=self.cleanup_textgrids) + self.to_export_queue.put((file_name, data)) + except Exception: + self.stopped.stop() + exc_type, exc_value, exc_traceback = sys.exc_info() + self.error_catching[("combining", self.job_name)] = "\n".join( + traceback.format_exception(exc_type, exc_value, exc_traceback) + ) + + +class ExportTextGridProcessWorker(mp.Process): + """ + Multiprocessing worker for exporting TextGrids + + See Also + -------- + 
:meth:`.CorpusAligner.ctms_to_textgrids_mp` + Main function that runs this worker in parallel + + Parameters + ---------- + for_write_queue: :class:`~multiprocessing.Queue` + Input queue of files to export + stopped: :class:`~montreal_forced_aligner.utils.Stopped` + Stop check for processing + finished_processing: :class:`~montreal_forced_aligner.utils.Stopped` + Input signal that all jobs have been added and no more new ones will come in + textgrid_errors: dict[str, str] + Dictionary for storing errors encountered + arguments: :class:`~montreal_forced_aligner.alignment.multiprocessing.ExportTextGridArguments` + Arguments to pass to the TextGrid export function + """ + + def __init__( + self, + for_write_queue: mp.Queue, + stopped: Stopped, + finished_processing: Stopped, + textgrid_errors: dict[str, str], + arguments: ExportTextGridArguments, + ): + mp.Process.__init__(self) + self.for_write_queue = for_write_queue + self.stopped = stopped + self.finished_processing = finished_processing + self.textgrid_errors = textgrid_errors + + self.files = arguments.files + self.output_directory = arguments.output_directory + self.backup_output_directory = arguments.backup_output_directory + + self.frame_shift = arguments.frame_shift + + def run(self) -> None: + """Run the exporter function""" + while True: + try: + file_name, data = self.for_write_queue.get(timeout=queue_polling_timeout) + except Empty: + if self.finished_processing.stop_check(): + break + continue + self.for_write_queue.task_done() + if self.stopped.stop_check(): + continue + try: + overwrite = True + file = self.files[file_name] + output_path = file.construct_output_path( + self.output_directory, self.backup_output_directory + ) + + export_textgrid(file, output_path, data, self.frame_shift, overwrite) + except Exception: + exc_type, exc_value, exc_traceback = sys.exc_info() + self.textgrid_errors[file_name] = "\n".join( + traceback.format_exception(exc_type, exc_value, exc_traceback) + ) + + +class 
ExportPreparationProcessWorker(mp.Process): + """ + Multiprocessing worker for preparing CTMs for export + + See Also + -------- + :meth:`.CorpusAligner.ctms_to_textgrids_mp` + Main function that runs this worker in parallel + + Parameters + ---------- + to_export_queue: :class:`~multiprocessing.Queue` + Input queue of combined CTMs + for_write_queue: :class:`~multiprocessing.Queue` + Export queue of files to export + stopped: :class:`~montreal_forced_aligner.utils.Stopped` + Stop check for processing + finished_combining: :class:`~montreal_forced_aligner.utils.Stopped` + Input signal that all CTMs have been combined + files: dict[str, File] + Files in corpus + """ + + def __init__( + self, + to_export_queue: mp.Queue, + for_write_queue: mp.Queue, + stopped: Stopped, + finished_combining: Stopped, + files: dict[str, File], + ): + mp.Process.__init__(self) + self.to_export_queue = to_export_queue + self.for_write_queue = for_write_queue + self.stopped = stopped + self.finished_combining = finished_combining + + self.files = files + + def run(self) -> None: + """Run the export preparation worker""" + export_data = {} + try: + while True: + try: + file_name, data = self.to_export_queue.get(timeout=queue_polling_timeout) + except Empty: + if self.finished_combining.stop_check(): + break + continue + self.to_export_queue.task_done() + if self.stopped.stop_check(): + continue + file = self.files[file_name] + if len(file.speaker_ordering) > 1: + if file_name not in export_data: + export_data[file_name] = data + else: + export_data[file_name].update(data) + if len(export_data[file_name]) == len(file.speaker_ordering): + data = export_data.pop(file_name) + self.for_write_queue.put((file_name, data)) + else: + self.for_write_queue.put((file_name, data)) + + for k, v in export_data.items(): + self.for_write_queue.put((k, v)) + except Exception: + self.stopped.stop() + raise diff --git a/montreal_forced_aligner/alignment/pretrained.py 
b/montreal_forced_aligner/alignment/pretrained.py new file mode 100644 index 00000000..813fcf18 --- /dev/null +++ b/montreal_forced_aligner/alignment/pretrained.py @@ -0,0 +1,428 @@ +"""Class definitions for aligning with pretrained acoustic models""" +from __future__ import annotations + +import os +import subprocess +import time +from collections import Counter, defaultdict +from typing import TYPE_CHECKING, NamedTuple, Optional + +import yaml + +from montreal_forced_aligner.abc import TopLevelMfaWorker +from montreal_forced_aligner.alignment.base import CorpusAligner +from montreal_forced_aligner.exceptions import KaldiProcessingError +from montreal_forced_aligner.helper import parse_old_features +from montreal_forced_aligner.models import AcousticModel +from montreal_forced_aligner.utils import log_kaldi_errors, run_mp, run_non_mp, thirdparty_binary + +if TYPE_CHECKING: + from argparse import Namespace + + from montreal_forced_aligner.abc import MetaDict + +__all__ = ["PretrainedAligner"] + + +def generate_pronunciations_func( + log_path: str, + dictionaries: list[str], + text_int_paths: dict[str, str], + word_boundary_paths: dict[str, str], + ali_paths: dict[str, str], + model_path: str, + pron_paths: dict[str, str], +): + """ + Multiprocessing function for generating pronunciations + + See Also + -------- + :meth:`.DictionaryTrainer.export_lexicons` + Main function that calls this function in parallel + :meth:`.DictionaryTrainer.generate_pronunciations_arguments` + Job method for generating arguments for this function + :kaldi_src:`linear-to-nbest` + Kaldi binary this uses + + Parameters + ---------- + log_path: str + Path to save log output + dictionaries: list[str] + List of dictionary names + text_int_paths: dict[str, str] + Dictionary of text int files per dictionary name + word_boundary_paths: dict[str, str] + Dictionary of word boundary files per dictionary name + ali_paths: dict[str, str] + Dictionary of alignment archives per dictionary name + 
model_path: str + Path to acoustic model file + pron_paths: dict[str, str] + Dictionary of pronunciation archives per dictionary name + """ + with open(log_path, "w", encoding="utf8") as log_file: + for dict_name in dictionaries: + text_int_path = text_int_paths[dict_name] + word_boundary_path = word_boundary_paths[dict_name] + ali_path = ali_paths[dict_name] + pron_path = pron_paths[dict_name] + + lin_proc = subprocess.Popen( + [ + thirdparty_binary("linear-to-nbest"), + f"ark:{ali_path}", + f"ark:{text_int_path}", + "", + "", + "ark:-", + ], + stdout=subprocess.PIPE, + stderr=log_file, + env=os.environ, + ) + align_proc = subprocess.Popen( + [ + thirdparty_binary("lattice-align-words"), + word_boundary_path, + model_path, + "ark:-", + "ark:-", + ], + stdin=lin_proc.stdout, + stdout=subprocess.PIPE, + stderr=log_file, + env=os.environ, + ) + + prons_proc = subprocess.Popen( + [thirdparty_binary("nbest-to-prons"), model_path, "ark:-", pron_path], + stdin=align_proc.stdout, + stderr=log_file, + env=os.environ, + ) + prons_proc.communicate() + + +class GeneratePronunciationsArguments(NamedTuple): + """Arguments for :func:`~montreal_forced_aligner.alignment.pretrained.generate_pronunciations_func`""" + + log_path: str + dictionaries: list[str] + text_int_paths: dict[str, str] + word_boundary_paths: dict[str, str] + ali_paths: dict[str, str] + model_path: str + pron_paths: dict[str, str] + + +class PretrainedAligner(CorpusAligner, TopLevelMfaWorker): + """ + Class for aligning a dataset using a pretrained acoustic model + + Parameters + ---------- + acoustic_model_path : str + Path to acoustic model + + See Also + -------- + :class:`~montreal_forced_aligner.alignment.base.CorpusAligner` + For dictionary and corpus parsing parameters and alignment parameters + :class:`~montreal_forced_aligner.abc.TopLevelMfaWorker` + For top-level parameters + """ + + def __init__( + self, + acoustic_model_path: str, + **kwargs, + ): + self.acoustic_model = 
AcousticModel(acoustic_model_path) + kwargs.update(self.acoustic_model.parameters) + super().__init__(**kwargs) + + @property + def working_directory(self) -> str: + """Working directory""" + return self.workflow_directory + + def setup(self) -> None: + """Setup for alignment""" + if self.initialized: + return + begin = time.time() + try: + os.makedirs(self.working_log_directory, exist_ok=True) + check = self.check_previous_run() + if check: + self.logger.debug( + "There were some differences in the current run compared to the last one. " + "This may cause issues, run with --clean, if you hit an error." + ) + self.load_corpus() + self.acoustic_model.validate(self) + self.acoustic_model.export_model(self.working_directory) + self.acoustic_model.log_details(self.logger) + + except Exception as e: + if isinstance(e, KaldiProcessingError): + log_kaldi_errors(e.error_logs, self.logger) + e.update_log_file(self.logger) + raise + self.initialized = True + self.logger.debug(f"Setup for alignment in {time.time() - begin} seconds") + + @classmethod + def parse_parameters( + cls, + config_path: Optional[str] = None, + args: Optional[Namespace] = None, + unknown_args: Optional[list[str]] = None, + ) -> MetaDict: + """ + Parse parameters from a config path or command-line arguments + + Parameters + ---------- + config_path: str + Config path + args: :class:`~argparse.Namespace` + Command-line arguments from argparse + unknown_args: list[str], optional + Extra command-line arguments + + Returns + ------- + dict[str, Any] + Configuration parameters + """ + global_params = {} + if config_path and os.path.exists(config_path): + with open(config_path, "r", encoding="utf8") as f: + data = yaml.load(f, Loader=yaml.SafeLoader) + data = parse_old_features(data) + for k, v in data.items(): + if k == "features": + global_params.update(v) + else: + global_params[k] = v + global_params.update(cls.parse_args(args, unknown_args)) + return global_params + + @property + def configuration(self) 
-> MetaDict: + """Configuration for aligner""" + config = super().configuration + config.update( + { + "acoustic_model": self.acoustic_model.name, + } + ) + return config + + @property + def backup_output_directory(self) -> Optional[str]: + """Backup output directory if overwriting is not allowed""" + if self.overwrite: + return None + return os.path.join(self.working_directory, "textgrids") + + @property + def workflow_identifier(self) -> str: + """Aligner identifier""" + return "pretrained_aligner" + + def align(self) -> None: + """Run the aligner""" + self.setup() + done_path = os.path.join(self.working_directory, "done") + dirty_path = os.path.join(self.working_directory, "dirty") + if os.path.exists(done_path): + self.logger.info("Alignment already done, skipping.") + return + try: + log_dir = os.path.join(self.working_directory, "log") + os.makedirs(log_dir, exist_ok=True) + self.compile_train_graphs() + + self.logger.info("Performing first-pass alignment...") + self.speaker_independent = True + self.align_utterances() + self.compile_information() + if self.uses_speaker_adaptation: + self.logger.info("Calculating fMLLR for speaker adaptation...") + self.calc_fmllr() + + self.speaker_independent = False + self.logger.info("Performing second-pass alignment...") + self.align_utterances() + + self.compile_information() + + except Exception as e: + with open(dirty_path, "w"): + pass + if isinstance(e, KaldiProcessingError): + log_kaldi_errors(e.error_logs, self.logger) + e.update_log_file(self.logger) + raise + with open(done_path, "w"): + pass + + +class DictionaryTrainer(PretrainedAligner): + """ + Aligner for calculating pronunciation probabilities of dictionary entries + + Parameters + ---------- + calculate_silence_probs: bool + Flag for whether to calculate silence probabilities, default is False + min_count: int + Specifies the minimum count of words to include in derived probabilities, + affects probabilities of infrequent words more, default is 1 + + See 
Also + -------- + :class:`~montreal_forced_aligner.alignment.pretrained.PretrainedAligner` + For dictionary and corpus parsing parameters and alignment parameters + """ + + def __init__( + self, + calculate_silence_probs: bool = False, + min_count: int = 1, + **kwargs, + ): + super().__init__(**kwargs) + self.calculate_silence_probs = calculate_silence_probs + self.min_count = min_count + + def generate_pronunciations_arguments( + self, + ) -> list[GeneratePronunciationsArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.alignment.pretrained.generate_pronunciations_func` + + Returns + ------- + list[:class:`~montreal_forced_aligner.alignment.pretrained.GeneratePronunciationsArguments`] + Arguments for processing + """ + return [ + GeneratePronunciationsArguments( + os.path.join(self.working_log_directory, f"generate_pronunciations.{j.name}.log"), + j.current_dictionary_names, + j.construct_path_dictionary(self.data_directory, "text", "int.scp"), + j.word_boundary_int_files(), + j.construct_path_dictionary(self.working_directory, "ali", "ark"), + self.model_path, + j.construct_path_dictionary(self.working_directory, "prons", "scp"), + ) + for j in self.jobs + ] + + def export_lexicons(self, output_directory: str) -> None: + """ + Generate pronunciation probabilities for the dictionary + + Parameters + ---------- + output_directory: str + Directory in which to save new dictionaries + + See Also + -------- + :func:`~montreal_forced_aligner.alignment.pretrained.generate_pronunciations_func` + Multiprocessing helper function for each job + :meth:`.DictionaryTrainer.generate_pronunciations_arguments` + Job method for generating arguments for helper function + + """ + os.makedirs(output_directory, exist_ok=True) + jobs = self.generate_pronunciations_arguments() + if self.use_mp: + run_mp(generate_pronunciations_func, jobs, self.working_log_directory) + else: + run_non_mp(generate_pronunciations_func, jobs, self.working_log_directory) + 
pron_counts = {} + utt_mapping = {} + for j in self.jobs: + args = jobs[j.name] + dict_data = j.dictionary_data() + for dict_name, pron_path in args.pron_paths.items(): + if dict_name not in pron_counts: + pron_counts[dict_name] = defaultdict(Counter) + utt_mapping[dict_name] = {} + word_lookup = dict_data[dict_name].reversed_words_mapping + phone_lookup = self.dictionary_mapping[dict_name].reversed_phone_mapping + with open(pron_path, "r", encoding="utf8") as f: + last_utt = None + for line in f: + line = line.split() + utt = line[0] + if utt not in utt_mapping[dict_name]: + if last_utt is not None: + utt_mapping[dict_name][last_utt].append("") + utt_mapping[dict_name][utt] = [""] + last_utt = utt + + word = word_lookup[int(line[3])] + if word == "": + utt_mapping[dict_name][utt].append(word) + else: + pron = tuple(phone_lookup[int(x)].split("_")[0] for x in line[4:]) + pron_string = " ".join(pron) + utt_mapping[dict_name][utt].append(word + " " + pron_string) + pron_counts[dict_name][word][pron] += 1 + for dict_name, dictionary in self.dictionary_mapping.items(): + counts = pron_counts[dict_name] + mapping = utt_mapping[dict_name] + if self.calculate_silence_probs: + sil_before_counts = Counter() + nonsil_before_counts = Counter() + sil_after_counts = Counter() + nonsil_after_counts = Counter() + sils = ["", "", ""] + for v in mapping.values(): + for i, w in enumerate(v): + if w in sils: + continue + prev_w = v[i - 1] + next_w = v[i + 1] + if prev_w in sils: + sil_before_counts[w] += 1 + else: + nonsil_before_counts[w] += 1 + if next_w in sils: + sil_after_counts[w] += 1 + else: + nonsil_after_counts[w] += 1 + + dictionary.pronunciation_probabilities = True + for word, prons in dictionary.words.items(): + if word not in counts: + for p in prons: + p["probability"] = 1 + else: + total = 0 + best_pron = 0 + best_count = 0 + for p in prons: + p["probability"] = self.min_count + if p["pronunciation"] in counts[word]: + p["probability"] += 
counts[word][p["pronunciation"]] + total += p["probability"] + if p["probability"] > best_count: + best_pron = p["pronunciation"] + best_count = p["probability"] + for p in prons: + if p["pronunciation"] == best_pron: + p["probability"] = 1 + else: + p["probability"] /= total + dictionary.words[word] = prons + output_path = os.path.join(output_directory, dict_name + ".txt") + dictionary.export_lexicon(output_path, probability=True) diff --git a/montreal_forced_aligner/command_line/__init__.py b/montreal_forced_aligner/command_line/__init__.py index 395b613b..1b856a48 100644 --- a/montreal_forced_aligner/command_line/__init__.py +++ b/montreal_forced_aligner/command_line/__init__.py @@ -2,25 +2,32 @@ Command line functionality ========================== - """ -from .adapt import run_adapt_model # noqa -from .align import run_align_corpus # noqa -from .anchor import run_anchor # noqa -from .classify_speakers import run_classify_speakers # noqa -from .create_segments import run_create_segments # noqa -from .g2p import run_g2p # noqa -from .mfa import create_parser, main # noqa -from .model import download_model, inspect_model, list_model, run_model, save_model # noqa -from .train_acoustic_model import run_train_acoustic_model # noqa -from .train_dictionary import run_train_dictionary # noqa -from .train_g2p import run_train_g2p # noqa -from .train_ivector_extractor import run_train_ivector_extractor # noqa -from .train_lm import run_train_lm # noqa -from .transcribe import run_transcribe_corpus # noqa -from .utils import validate_model_arg # noqa -from .validate import run_validate_corpus # noqa +from montreal_forced_aligner.command_line.adapt import run_adapt_model +from montreal_forced_aligner.command_line.align import run_align_corpus +from montreal_forced_aligner.command_line.anchor import run_anchor +from montreal_forced_aligner.command_line.classify_speakers import run_classify_speakers +from montreal_forced_aligner.command_line.create_segments import 
run_create_segments +from montreal_forced_aligner.command_line.g2p import run_g2p +from montreal_forced_aligner.command_line.mfa import create_parser, main +from montreal_forced_aligner.command_line.model import ( + download_model, + inspect_model, + list_model, + run_model, + save_model, +) +from montreal_forced_aligner.command_line.train_acoustic_model import run_train_acoustic_model +from montreal_forced_aligner.command_line.train_dictionary import run_train_dictionary +from montreal_forced_aligner.command_line.train_g2p import run_train_g2p +from montreal_forced_aligner.command_line.train_ivector_extractor import ( + run_train_ivector_extractor, +) +from montreal_forced_aligner.command_line.train_lm import run_train_lm +from montreal_forced_aligner.command_line.transcribe import run_transcribe_corpus +from montreal_forced_aligner.command_line.utils import validate_model_arg +from montreal_forced_aligner.command_line.validate import run_validate_corpus __all__ = [ "adapt", diff --git a/montreal_forced_aligner/command_line/adapt.py b/montreal_forced_aligner/command_line/adapt.py index a421931d..befdcb2d 100644 --- a/montreal_forced_aligner/command_line/adapt.py +++ b/montreal_forced_aligner/command_line/adapt.py @@ -2,23 +2,12 @@ from __future__ import annotations import os -import shutil import time -from typing import TYPE_CHECKING, Collection, Optional +from typing import TYPE_CHECKING, Optional -from montreal_forced_aligner.aligner import AdaptingAligner, PretrainedAligner, TrainableAligner +from montreal_forced_aligner.alignment import AdaptingAligner from montreal_forced_aligner.command_line.utils import validate_model_arg -from montreal_forced_aligner.config import ( - TEMP_DIR, - align_yaml_to_config, - load_basic_align, - load_command_configuration, -) -from montreal_forced_aligner.corpus import Corpus -from montreal_forced_aligner.dictionary import MultispeakerDictionary from montreal_forced_aligner.exceptions import ArgumentError -from 
montreal_forced_aligner.models import AcousticModel -from montreal_forced_aligner.utils import get_mfa_version, log_config, setup_logger if TYPE_CHECKING: from argparse import Namespace @@ -26,7 +15,7 @@ __all__ = ["adapt_model", "validate_args", "run_adapt_model"] -def adapt_model(args: Namespace, unknown_args: Optional[Collection[str]] = None) -> None: +def adapt_model(args: Namespace, unknown_args: Optional[list[str]] = None) -> None: """ Run the acoustic model adaptation @@ -34,211 +23,42 @@ def adapt_model(args: Namespace, unknown_args: Optional[Collection[str]] = None) ---------- args: :class:`~argparse.Namespace` Command line arguments - unknown_args: List[str] + unknown_args: list[str] Optional arguments that will be passed to configuration objects """ - command = "adapt" - all_begin = time.time() - if not args.temp_directory: - temp_dir = TEMP_DIR - else: - temp_dir = os.path.expanduser(args.temp_directory) - corpus_name = os.path.basename(args.corpus_directory) - if corpus_name == "": - args.corpus_directory = os.path.dirname(args.corpus_directory) - corpus_name = os.path.basename(args.corpus_directory) - data_directory = os.path.join(temp_dir, corpus_name) - if args.config_path: - align_config, dictionary_config = align_yaml_to_config(args.config_path) - else: - align_config, dictionary_config = load_basic_align() - align_config.update_from_args(args) - if unknown_args: - align_config.update_from_unknown_args(unknown_args) - conf_path = os.path.join(data_directory, "config.yml") - if getattr(args, "clean", False) and os.path.exists(data_directory): - print("Cleaning old directory!") - shutil.rmtree(data_directory, ignore_errors=True) - if getattr(args, "verbose", False): - log_level = "debug" - else: - log_level = "info" - logger = setup_logger(command, data_directory, console_level=log_level) - logger.debug("ALIGN CONFIG:") - log_config(logger, align_config) - logger.debug("DICTIONARY CONFIG:") - log_config(logger, dictionary_config) - conf = 
load_command_configuration( - conf_path, - { - "dirty": False, - "begin": all_begin, - "version": get_mfa_version(), - "type": command, - "corpus_directory": args.corpus_directory, - "dictionary_path": args.dictionary_path, - "acoustic_model_path": args.acoustic_model_path, - }, + adapter = AdaptingAligner( + acoustic_model_path=args.acoustic_model_path, + corpus_directory=args.corpus_directory, + dictionary_path=args.dictionary_path, + temporary_directory=args.temporary_directory, + **AdaptingAligner.parse_parameters(args.config_path, args, unknown_args), ) - if ( - conf["dirty"] - or conf["type"] != command - or conf["corpus_directory"] != args.corpus_directory - or conf["version"] != get_mfa_version() - or conf["dictionary_path"] != args.dictionary_path - ): - logger.warning( - "WARNING: Using old temp directory, this might not be ideal for you, use the --clean flag to ensure no " - "weird behavior for previous versions of the temporary directory." - ) - if conf["dirty"]: - logger.debug("Previous run ended in an error (maybe ctrl-c?)") - if conf["type"] != command: - logger.debug( - f"Previous run was a different subcommand than {command} (was {conf['type']})" - ) - if conf["corpus_directory"] != args.corpus_directory: - logger.debug( - "Previous run used source directory " - f"path {conf['corpus_directory']} (new run: {args.corpus_directory})" - ) - if conf["version"] != get_mfa_version(): - logger.debug( - f"Previous run was on {conf['version']} version (new run: {get_mfa_version()})" - ) - if conf["dictionary_path"] != args.dictionary_path: - logger.debug( - f"Previous run used dictionary path {conf['dictionary_path']} " - f"(new run: {args.dictionary_path})" - ) - if conf["acoustic_model_path"] != args.acoustic_model_path: - logger.debug( - f"Previous run used acoustic model path {conf['acoustic_model_path']} " - f"(new run: {args.acoustic_model_path})" - ) - os.makedirs(data_directory, exist_ok=True) - model_directory = os.path.join(data_directory, 
"acoustic_models") - os.makedirs(model_directory, exist_ok=True) - acoustic_model = AcousticModel(args.acoustic_model_path, root_directory=model_directory) - dictionary_config.update(acoustic_model.meta) - acoustic_model.log_details(logger) - debug = getattr(args, "debug", False) - audio_dir = None - if args.audio_directory: - audio_dir = args.audio_directory try: - corpus = Corpus( - args.corpus_directory, - data_directory, - dictionary_config, - speaker_characters=args.speaker_characters, - num_jobs=args.num_jobs, - sample_rate=align_config.feature_config.sample_frequency, - logger=logger, - use_mp=align_config.use_mp, - audio_directory=audio_dir, - ) - logger.info(corpus.speaker_utterance_info()) - dictionary = MultispeakerDictionary( - args.dictionary_path, - data_directory, - dictionary_config, - logger=logger, - ) - acoustic_model.validate(dictionary) - - begin = time.time() - previous = PretrainedAligner( - corpus, - dictionary, - acoustic_model, - align_config, - temp_directory=data_directory, - debug=debug, - logger=logger, - ) - if args.full_train: - training_config, dictionary = acoustic_model.adaptation_config() - training_config.training_configs[0].update( - {"beam": align_config.beam, "retry_beam": align_config.retry_beam} - ) - training_config.update_from_align(align_config) - logger.debug("ADAPT TRAINING CONFIG:") - log_config(logger, training_config) - a = TrainableAligner( - corpus, - dictionary, - training_config, - align_config, - temp_directory=data_directory, - debug=debug, - logger=logger, - pretrained_aligner=previous, - ) - logger.debug(f"Setup adapter trainer in {time.time() - begin} seconds") - a.verbose = args.verbose - generate_final_alignments = True - if args.output_directory is None: - generate_final_alignments = False - else: - os.makedirs(args.output_directory, exist_ok=True) + adapter.adapt() + generate_final_alignments = True + if args.output_directory is None: + generate_final_alignments = False + else: + 
os.makedirs(args.output_directory, exist_ok=True) + export_model = True + if args.output_model_path is None: + export_model = False + if generate_final_alignments: begin = time.time() - a.train(generate_final_alignments) - logger.debug(f"Trained adapted model in {time.time() - begin} seconds") - if args.output_model_path is not None: - a.save(args.output_model_path, root_directory=model_directory) - - if generate_final_alignments: - a.export_textgrids(args.output_directory) - - a.save(args.output_model_path, root_directory=model_directory) - else: - a = AdaptingAligner( - corpus, - dictionary, - previous, - align_config, - temp_directory=data_directory, - debug=debug, - logger=logger, + adapter.align() + adapter.logger.debug( + f"Generated alignments with adapted model in {time.time() - begin} seconds" ) - logger.debug(f"Setup adapter trainer in {time.time() - begin} seconds") - a.verbose = args.verbose - generate_final_alignments = True - if args.output_directory is None: - generate_final_alignments = False - else: - os.makedirs(args.output_directory, exist_ok=True) - begin = time.time() - a.train() - logger.debug(f"Mapped adapted model in {time.time() - begin} seconds") - if args.output_model_path is not None: - a.save(args.output_model_path, root_directory=model_directory) - if generate_final_alignments: - begin = time.time() - a.align() - logger.debug( - f"Generated alignments with adapted model in {time.time() - begin} seconds" - ) - - if generate_final_alignments: - a.export_textgrids(args.output_directory) - - a.save(args.output_model_path, root_directory=model_directory) - - logger.info("All done!") - logger.debug(f"Done! 
Everything took {time.time() - all_begin} seconds") + adapter.export_files(args.output_directory) + if export_model: + adapter.export_model(args.output_model_path) except Exception: - conf["dirty"] = True + adapter.dirty = True raise finally: - handlers = logger.handlers[:] - for handler in handlers: - handler.close() - logger.removeHandler(handler) - conf.save(conf_path) + adapter.cleanup() def validate_args(args: Namespace) -> None: @@ -284,7 +104,7 @@ def validate_args(args: Namespace) -> None: args.acoustic_model_path = validate_model_arg(args.acoustic_model_path, "acoustic") -def run_adapt_model(args: Namespace, unknown_args: Optional[Collection] = None) -> None: +def run_adapt_model(args: Namespace, unknown_args: Optional[list[str]] = None) -> None: """ Wrapper function for running acoustic model adaptation @@ -292,7 +112,7 @@ def run_adapt_model(args: Namespace, unknown_args: Optional[Collection] = None) ---------- args: :class:`~argparse.Namespace` Parsed command line arguments - unknown: List[str] + unknown: list[str] Parsed command line arguments to be passed to the configuration objects """ validate_args(args) diff --git a/montreal_forced_aligner/command_line/align.py b/montreal_forced_aligner/command_line/align.py index 33892702..c9d763c6 100644 --- a/montreal_forced_aligner/command_line/align.py +++ b/montreal_forced_aligner/command_line/align.py @@ -2,23 +2,11 @@ from __future__ import annotations import os -import shutil -import time from typing import TYPE_CHECKING, Optional -from montreal_forced_aligner.aligner import PretrainedAligner +from montreal_forced_aligner.alignment import PretrainedAligner from montreal_forced_aligner.command_line.utils import validate_model_arg -from montreal_forced_aligner.config import ( - TEMP_DIR, - align_yaml_to_config, - load_basic_align, - load_command_configuration, -) -from montreal_forced_aligner.corpus import Corpus -from montreal_forced_aligner.dictionary import MultispeakerDictionary from 
montreal_forced_aligner.exceptions import ArgumentError -from montreal_forced_aligner.models import AcousticModel -from montreal_forced_aligner.utils import log_config, setup_logger if TYPE_CHECKING: from argparse import Namespace @@ -27,7 +15,7 @@ __all__ = ["align_corpus", "validate_args", "run_align_corpus"] -def align_corpus(args: Namespace, unknown_args: Optional[list] = None) -> None: +def align_corpus(args: Namespace, unknown_args: Optional[list[str]] = None) -> None: """ Run the alignment @@ -35,156 +23,25 @@ def align_corpus(args: Namespace, unknown_args: Optional[list] = None) -> None: ---------- args: :class:`~argparse.Namespace` Command line arguments - unknown_args: List[str] + unknown_args: list[str] Optional arguments that will be passed to configuration objects """ - from montreal_forced_aligner.utils import get_mfa_version - - command = "align" - all_begin = time.time() - if not args.temp_directory: - temp_dir = TEMP_DIR - else: - temp_dir = os.path.expanduser(args.temp_directory) - corpus_name = os.path.basename(args.corpus_directory) - if corpus_name == "": - args.corpus_directory = os.path.dirname(args.corpus_directory) - corpus_name = os.path.basename(args.corpus_directory) - data_directory = os.path.join(temp_dir, corpus_name) - if args.config_path: - align_config, dictionary_config = align_yaml_to_config(args.config_path) - else: - align_config, dictionary_config = load_basic_align() - align_config.update_from_args(args) - dictionary_config.update_from_args(args) - if unknown_args: - align_config.update_from_unknown_args(unknown_args) - dictionary_config.update_from_unknown_args(unknown_args) - conf_path = os.path.join(data_directory, "config.yml") - if getattr(args, "clean", False) and os.path.exists(data_directory): - print("Cleaning old directory!") - shutil.rmtree(data_directory, ignore_errors=True) - if getattr(args, "verbose", False): - log_level = "debug" - else: - log_level = "info" - logger = setup_logger(command, data_directory, 
console_level=log_level) - logger.debug("ALIGN CONFIG:") - log_config(logger, align_config) - conf = load_command_configuration( - conf_path, - { - "dirty": False, - "begin": all_begin, - "version": get_mfa_version(), - "type": command, - "corpus_directory": args.corpus_directory, - "dictionary_path": args.dictionary_path, - "acoustic_model_path": args.acoustic_model_path, - }, + aligner = PretrainedAligner( + acoustic_model_path=args.acoustic_model_path, + corpus_directory=args.corpus_directory, + audio_directory=args.audio_directory, + dictionary_path=args.dictionary_path, + temporary_directory=args.temporary_directory, + **PretrainedAligner.parse_parameters(args.config_path, args, unknown_args), ) - - if ( - conf["dirty"] - or conf["type"] != command - or conf["corpus_directory"] != args.corpus_directory - or conf["version"] != get_mfa_version() - or conf["dictionary_path"] != args.dictionary_path - ): - logger.warning( - "WARNING: Using old temp directory, this might not be ideal for you, use the --clean flag to ensure no " - "weird behavior for previous versions of the temporary directory." 
- ) - if conf["dirty"]: - logger.debug("Previous run ended in an error (maybe ctrl-c?)") - if conf["type"] != command: - logger.debug( - f"Previous run was a different subcommand than {command} (was {conf['type']})" - ) - if conf["corpus_directory"] != args.corpus_directory: - logger.debug( - "Previous run used source directory " - f"path {conf['corpus_directory']} (new run: {args.corpus_directory})" - ) - if conf["version"] != get_mfa_version(): - logger.debug( - f"Previous run was on {conf['version']} version (new run: {get_mfa_version()})" - ) - if conf["dictionary_path"] != args.dictionary_path: - logger.debug( - f"Previous run used dictionary path {conf['dictionary_path']} " - f"(new run: {args.dictionary_path})" - ) - if conf["acoustic_model_path"] != args.acoustic_model_path: - logger.debug( - f"Previous run used acoustic model path {conf['acoustic_model_path']} " - f"(new run: {args.acoustic_model_path})" - ) - - os.makedirs(data_directory, exist_ok=True) - model_directory = os.path.join(data_directory, "acoustic_models") - os.makedirs(model_directory, exist_ok=True) - os.makedirs(args.output_directory, exist_ok=True) - acoustic_model = AcousticModel(args.acoustic_model_path, root_directory=model_directory) - dictionary_config.update(acoustic_model.meta) - acoustic_model.log_details(logger) - audio_dir = None - if args.audio_directory: - audio_dir = args.audio_directory try: - corpus = Corpus( - args.corpus_directory, - data_directory, - dictionary_config, - speaker_characters=args.speaker_characters, - num_jobs=args.num_jobs, - sample_rate=align_config.feature_config.sample_frequency, - logger=logger, - use_mp=align_config.use_mp, - audio_directory=audio_dir, - ) - logger.info(corpus.speaker_utterance_info()) - dictionary = MultispeakerDictionary( - args.dictionary_path, - data_directory, - dictionary_config, - logger=logger, - word_set=corpus.word_set, - ) - - acoustic_model.validate(dictionary) - - begin = time.time() - a = PretrainedAligner( - corpus, - 
dictionary, - acoustic_model, - align_config, - temp_directory=data_directory, - debug=getattr(args, "debug", False), - logger=logger, - ) - logger.debug(f"Setup pretrained aligner in {time.time() - begin} seconds") - a.verbose = args.verbose - - begin = time.time() - a.align() - logger.debug(f"Performed alignment in {time.time() - begin} seconds") - - begin = time.time() - a.export_textgrids(args.output_directory) - logger.debug(f"Exported TextGrids in {time.time() - begin} seconds") - logger.info("All done!") - logger.debug(f"Done! Everything took {time.time() - all_begin} seconds") + aligner.align() + aligner.export_files(args.output_directory) except Exception: - conf["dirty"] = True + aligner.dirty = True raise finally: - handlers = logger.handlers[:] - for handler in handlers: - handler.close() - logger.removeHandler(handler) - conf.save(conf_path) + aligner.cleanup() def validate_args(args: Namespace) -> None: @@ -221,7 +78,7 @@ def validate_args(args: Namespace) -> None: args.acoustic_model_path = validate_model_arg(args.acoustic_model_path, "acoustic") -def run_align_corpus(args: Namespace, unknown_args: Optional[list] = None) -> None: +def run_align_corpus(args: Namespace, unknown_args: Optional[list[str]] = None) -> None: """ Wrapper function for running alignment @@ -229,7 +86,7 @@ def run_align_corpus(args: Namespace, unknown_args: Optional[list] = None) -> No ---------- args: :class:`~argparse.Namespace` Parsed command line arguments - unknown: List[str] + unknown: list[str] Parsed command line arguments to be passed to the configuration objects """ validate_args(args) diff --git a/montreal_forced_aligner/command_line/classify_speakers.py b/montreal_forced_aligner/command_line/classify_speakers.py index 39d3ae25..b1f75ab0 100644 --- a/montreal_forced_aligner/command_line/classify_speakers.py +++ b/montreal_forced_aligner/command_line/classify_speakers.py @@ -2,22 +2,11 @@ from __future__ import annotations import os -import shutil -import time from 
typing import TYPE_CHECKING, Optional from montreal_forced_aligner.command_line.utils import validate_model_arg -from montreal_forced_aligner.config import ( - TEMP_DIR, - classification_yaml_to_config, - load_basic_classification, - load_command_configuration, -) -from montreal_forced_aligner.corpus import Corpus from montreal_forced_aligner.exceptions import ArgumentError -from montreal_forced_aligner.models import IvectorExtractorModel from montreal_forced_aligner.speaker_classifier import SpeakerClassifier -from montreal_forced_aligner.utils import setup_logger if TYPE_CHECKING: from argparse import Namespace @@ -25,7 +14,7 @@ __all__ = ["classify_speakers", "validate_args", "run_classify_speakers"] -def classify_speakers(args: Namespace, unknown_args: Optional[list] = None) -> None: +def classify_speakers(args: Namespace, unknown_args: Optional[list[str]] = None) -> None: """ Run the speaker classification @@ -33,129 +22,25 @@ def classify_speakers(args: Namespace, unknown_args: Optional[list] = None) -> N ---------- args: :class:`~argparse.Namespace` Command line arguments - unknown_args: List[str] + unknown_args: list[str] Optional arguments that will be passed to configuration objects """ - from montreal_forced_aligner.utils import get_mfa_version - - command = "classify_speakers" - all_begin = time.time() - if not args.temp_directory: - temp_dir = TEMP_DIR - else: - temp_dir = os.path.expanduser(args.temp_directory) - corpus_name = os.path.basename(args.corpus_directory) - if corpus_name == "": - args.corpus_directory = os.path.dirname(args.corpus_directory) - corpus_name = os.path.basename(args.corpus_directory) - data_directory = os.path.join(temp_dir, corpus_name) - conf_path = os.path.join(data_directory, "config.yml") - if args.config_path: - classification_config = classification_yaml_to_config(args.config_path) - else: - classification_config = load_basic_classification() - classification_config.use_mp = not args.disable_mp - 
classification_config.overwrite = args.overwrite - if unknown_args: - classification_config.update_from_unknown_args(unknown_args) - classification_config.use_mp = not args.disable_mp - if getattr(args, "clean", False) and os.path.exists(data_directory): - print("Cleaning old directory!") - shutil.rmtree(data_directory, ignore_errors=True) - if getattr(args, "verbose", False): - log_level = "debug" - else: - log_level = "info" - logger = setup_logger(command, data_directory, console_level=log_level) - conf = load_command_configuration( - conf_path, - { - "dirty": False, - "begin": time.time(), - "version": get_mfa_version(), - "type": command, - "corpus_directory": args.corpus_directory, - "ivector_extractor_path": args.ivector_extractor_path, - }, + classifier = SpeakerClassifier( + ivector_extractor_path=args.ivector_extractor_path, + corpus_directory=args.corpus_directory, + temporary_directory=args.temporary_directory, + **SpeakerClassifier.parse_parameters(args.config_path, args, unknown_args), ) - if ( - conf["dirty"] - or conf["type"] != command - or conf["corpus_directory"] != args.corpus_directory - or conf["version"] != get_mfa_version() - ): - logger.warning( - "WARNING: Using old temp directory, this might not be ideal for you, use the --clean flag to ensure no " - "weird behavior for previous versions of the temporary directory." 
- ) - if conf["dirty"]: - logger.debug("Previous run ended in an error (maybe ctrl-c?)") - if conf["type"] != command: - logger.debug( - f"Previous run was a different subcommand than {command} (was {conf['type']})" - ) - if conf["corpus_directory"] != args.corpus_directory: - logger.debug( - "Previous run used source directory " - f"path {conf['corpus_directory']} (new run: {args.corpus_directory})" - ) - if conf["version"] != get_mfa_version(): - logger.debug( - f"Previous run was on {conf['version']} version (new run: {get_mfa_version()})" - ) - if conf["ivector_extractor_path"] != args.ivector_extractor_path: - logger.debug( - f"Previous run used ivector extractor path {conf['ivector_extractor_path']} " - f"(new run: {args.ivector_extractor_path})" - ) - - os.makedirs(data_directory, exist_ok=True) - os.makedirs(args.output_directory, exist_ok=True) try: - ivector_extractor = IvectorExtractorModel( - args.ivector_extractor_path, root_directory=data_directory - ) - corpus = Corpus( - args.corpus_directory, - data_directory, - sample_rate=ivector_extractor.feature_config.sample_frequency, - num_jobs=args.num_jobs, - logger=logger, - use_mp=classification_config.use_mp, - ) - begin = time.time() - a = SpeakerClassifier( - corpus, - ivector_extractor, - classification_config, - temp_directory=data_directory, - debug=getattr(args, "debug", False), - logger=logger, - num_speakers=args.num_speakers, - cluster=args.cluster, - ) - logger.debug(f"Setup speaker classifier in {time.time() - begin} seconds") - a.verbose = args.verbose - - begin = time.time() - a.cluster_utterances() - logger.debug(f"Performed clustering in {time.time() - begin} seconds") - - begin = time.time() - a.export_classification(args.output_directory) - logger.debug(f"Exported classification in {time.time() - begin} seconds") - logger.info("Done!") - logger.debug(f"Done! 
Everything took {time.time() - all_begin} seconds") + classifier.cluster_utterances() + + classifier.export_files(args.output_directory) except Exception: - conf["dirty"] = True + classifier.dirty = True raise finally: - handlers = logger.handlers[:] - for handler in handlers: - handler.close() - logger.removeHandler(handler) - conf.save(conf_path) + classifier.cleanup() def validate_args(args: Namespace) -> None: @@ -189,7 +74,7 @@ def validate_args(args: Namespace) -> None: args.ivector_extractor_path = validate_model_arg(args.ivector_extractor_path, "ivector") -def run_classify_speakers(args: Namespace, unknown: Optional[list] = None) -> None: +def run_classify_speakers(args: Namespace, unknown: Optional[list[str]] = None) -> None: """ Wrapper function for running speaker classification @@ -197,7 +82,7 @@ def run_classify_speakers(args: Namespace, unknown: Optional[list] = None) -> No ---------- args: :class:`~argparse.Namespace` Parsed command line arguments - unknown: List[str] + unknown: list[str] Parsed command line arguments to be passed to the configuration objects """ validate_args(args) diff --git a/montreal_forced_aligner/command_line/create_segments.py b/montreal_forced_aligner/command_line/create_segments.py index 129667f4..a818bfa2 100644 --- a/montreal_forced_aligner/command_line/create_segments.py +++ b/montreal_forced_aligner/command_line/create_segments.py @@ -2,20 +2,10 @@ from __future__ import annotations import os -import shutil -import time from typing import TYPE_CHECKING, Optional -from montreal_forced_aligner.config import ( - TEMP_DIR, - load_basic_segmentation, - load_command_configuration, - segmentation_yaml_to_config, -) -from montreal_forced_aligner.corpus import Corpus from montreal_forced_aligner.exceptions import ArgumentError from montreal_forced_aligner.segmenter import Segmenter -from montreal_forced_aligner.utils import log_config, setup_logger if TYPE_CHECKING: from argparse import Namespace @@ -24,7 +14,7 @@ __all__ = 
["create_segments", "validate_args", "run_create_segments"] -def create_segments(args: Namespace, unknown_args: Optional[list] = None) -> None: +def create_segments(args: Namespace, unknown_args: Optional[list[str]] = None) -> None: """ Run the sound file segmentation @@ -32,117 +22,23 @@ def create_segments(args: Namespace, unknown_args: Optional[list] = None) -> Non ---------- args: :class:`~argparse.Namespace` Command line arguments - unknown_args: List[str] + unknown_args: list[str] Optional arguments that will be passed to configuration objects """ - from montreal_forced_aligner.utils import get_mfa_version - - command = "create_segments" - all_begin = time.time() - if not args.temp_directory: - temp_dir = TEMP_DIR - else: - temp_dir = os.path.expanduser(args.temp_directory) - corpus_name = os.path.basename(args.corpus_directory) - if corpus_name == "": - args.corpus_directory = os.path.dirname(args.corpus_directory) - corpus_name = os.path.basename(args.corpus_directory) - data_directory = os.path.join(temp_dir, corpus_name) - conf_path = os.path.join(data_directory, "config.yml") - if args.config_path: - segmentation_config = segmentation_yaml_to_config(args.config_path) - else: - segmentation_config = load_basic_segmentation() - segmentation_config.update_from_args(args) - if unknown_args: - segmentation_config.update_from_unknown_args(unknown_args) - if getattr(args, "clean", False) and os.path.exists(data_directory): - print("Cleaning old directory!") - shutil.rmtree(data_directory, ignore_errors=True) - if getattr(args, "verbose", False): - log_level = "debug" - else: - log_level = "info" - logger = setup_logger(command, data_directory, console_level=log_level) - log_config(logger, segmentation_config) - conf = load_command_configuration( - conf_path, - { - "dirty": False, - "begin": time.time(), - "version": get_mfa_version(), - "type": command, - "corpus_directory": args.corpus_directory, - }, + + segmenter = Segmenter( + 
corpus_directory=args.corpus_directory, + temporary_directory=args.temporary_directory, + **Segmenter.parse_parameters(args.config_path, args, unknown_args), ) - if ( - conf["dirty"] - or conf["type"] != command - or conf["corpus_directory"] != args.corpus_directory - or conf["version"] != get_mfa_version() - ): - logger.warning( - "WARNING: Using old temp directory, this might not be ideal for you, use the --clean flag to ensure no " - "weird behavior for previous versions of the temporary directory." - ) - if conf["dirty"]: - logger.debug("Previous run ended in an error (maybe ctrl-c?)") - if conf["type"] != command: - logger.debug( - f"Previous run was a different subcommand than {command} (was {conf['type']})" - ) - if conf["corpus_directory"] != args.corpus_directory: - logger.debug( - "Previous run used source directory " - f"path {conf['corpus_directory']} (new run: {args.corpus_directory})" - ) - if conf["version"] != get_mfa_version(): - logger.debug( - f"Previous run was on {conf['version']} version (new run: {get_mfa_version()})" - ) - - os.makedirs(data_directory, exist_ok=True) - os.makedirs(args.output_directory, exist_ok=True) try: - corpus = Corpus( - args.corpus_directory, - data_directory, - sample_rate=segmentation_config.feature_config.sample_frequency, - num_jobs=args.num_jobs, - logger=logger, - use_mp=segmentation_config.use_mp, - ignore_speakers=True, - ) - - begin = time.time() - a = Segmenter( - corpus, - segmentation_config, - temp_directory=data_directory, - debug=getattr(args, "debug", False), - logger=logger, - ) - logger.debug(f"Setup segmenter in {time.time() - begin} seconds") - a.verbose = args.verbose - - begin = time.time() - a.segment() - logger.debug(f"Performed segmentation in {time.time() - begin} seconds") - - begin = time.time() - a.export_segments(args.output_directory) - logger.debug(f"Exported segmentation in {time.time() - begin} seconds") - logger.info("Done!") - logger.debug(f"Done! 
Everything took {time.time() - all_begin} seconds") + segmenter.segment() + segmenter.export_files(args.output_directory) except Exception: - conf["dirty"] = True + segmenter.dirty = True raise finally: - handlers = logger.handlers[:] - for handler in handlers: - handler.close() - logger.removeHandler(handler) - conf.save(conf_path) + segmenter.cleanup() def validate_args(args: Namespace) -> None: @@ -172,7 +68,7 @@ def validate_args(args: Namespace) -> None: raise ArgumentError("Corpus directory and output directory cannot be the same folder.") -def run_create_segments(args: Namespace, unknown: Optional[list] = None) -> None: +def run_create_segments(args: Namespace, unknown: Optional[list[str]] = None) -> None: """ Wrapper function for running sound file segmentation @@ -180,7 +76,7 @@ def run_create_segments(args: Namespace, unknown: Optional[list] = None) -> None ---------- args: :class:`~argparse.Namespace` Parsed command line arguments - unknown: List[str] + unknown: list[str] Parsed command line arguments to be passed to the configuration objects """ validate_args(args) diff --git a/montreal_forced_aligner/command_line/g2p.py b/montreal_forced_aligner/command_line/g2p.py index f5e33184..d465ae60 100644 --- a/montreal_forced_aligner/command_line/g2p.py +++ b/montreal_forced_aligner/command_line/g2p.py @@ -2,16 +2,15 @@ from __future__ import annotations import os -import shutil from typing import TYPE_CHECKING, Optional from montreal_forced_aligner.command_line.utils import validate_model_arg -from montreal_forced_aligner.config import TEMP_DIR -from montreal_forced_aligner.config.g2p_config import g2p_yaml_to_config, load_basic_g2p_config -from montreal_forced_aligner.corpus import Corpus -from montreal_forced_aligner.g2p.generator import PyniniDictionaryGenerator as Generator -from montreal_forced_aligner.models import G2PModel -from montreal_forced_aligner.utils import setup_logger +from montreal_forced_aligner.g2p.generator import ( + 
OrthographicCorpusGenerator, + OrthographicWordListGenerator, + PyniniCorpusGenerator, + PyniniWordListGenerator, +) if TYPE_CHECKING: from argparse import Namespace @@ -20,7 +19,7 @@ __all__ = ["generate_dictionary", "validate_args", "run_g2p"] -def generate_dictionary(args: Namespace, unknown_args: Optional[list] = None) -> None: +def generate_dictionary(args: Namespace, unknown_args: Optional[list[str]] = None) -> None: """ Run the G2P command @@ -28,90 +27,52 @@ def generate_dictionary(args: Namespace, unknown_args: Optional[list] = None) -> ---------- args: :class:`~argparse.Namespace` Command line arguments - unknown_args: List[str] + unknown_args: list[str] Optional arguments that will be passed to configuration objects """ - command = "g2p" - if not args.temp_directory: - temp_dir = TEMP_DIR - temp_dir = os.path.join(temp_dir, "G2P") - else: - temp_dir = os.path.expanduser(args.temp_directory) - if args.clean: - shutil.rmtree(os.path.join(temp_dir, "G2P"), ignore_errors=True) - shutil.rmtree(os.path.join(temp_dir, "models", "G2P"), ignore_errors=True) - if args.config_path: - g2p_config, dictionary_config = g2p_yaml_to_config(args.config_path) - else: - g2p_config, dictionary_config = load_basic_g2p_config() - g2p_config.use_mp = not args.disable_mp - if unknown_args: - g2p_config.update_from_unknown_args(unknown_args) - if os.path.isdir(args.input_path): - input_dir = os.path.expanduser(args.input_path) - corpus_name = os.path.basename(args.input_path) - if corpus_name == "": - args.input_path = os.path.dirname(args.input_path) - corpus_name = os.path.basename(args.input_path) - data_directory = os.path.join(temp_dir, corpus_name) - if getattr(args, "verbose", False): - log_level = "debug" - else: - log_level = "info" - logger = setup_logger(command, data_directory, console_level=log_level) - - corpus = Corpus( - input_dir, - data_directory, - dictionary_config=dictionary_config, - num_jobs=args.num_jobs, - use_mp=g2p_config.use_mp, - 
parse_text_only_files=True, - ) - - word_set = corpus.word_set - if not args.include_bracketed: - word_set = [x for x in word_set if not dictionary_config.check_bracketed(x)] - else: - if getattr(args, "verbose", False): - log_level = "debug" + if args.g2p_model_path is None: + if os.path.isdir(args.input_path): + g2p = OrthographicCorpusGenerator( + corpus_directory=args.input_path, + temporary_directory=args.temporary_directory, + **OrthographicCorpusGenerator.parse_parameters( + args.config_path, args, unknown_args + ) + ) else: - log_level = "info" - logger = setup_logger(command, temp_dir, console_level=log_level) - word_set = [] - with open(args.input_path, "r", encoding="utf8") as f: - for line in f: - word_set.extend(line.strip().split()) - if not args.include_bracketed: - word_set = [x for x in word_set if not dictionary_config.check_bracketed(x)] - - logger.info( - f"Generating transcriptions for the {len(word_set)} word types found in the corpus..." - ) - if args.g2p_model_path is not None: - model = G2PModel( - args.g2p_model_path, root_directory=os.path.join(temp_dir, "models", "G2P") - ) - model.validate(word_set) - num_jobs = args.num_jobs - if not g2p_config.use_mp: - num_jobs = 1 - gen = Generator( - model, - word_set, - temp_directory=temp_dir, - num_jobs=num_jobs, - num_pronunciations=g2p_config.num_pronunciations, - logger=logger, - ) - gen.output(args.output_path) - model.clean_up() + g2p = OrthographicWordListGenerator( + word_list_path=args.input_path, + temporary_directory=args.temporary_directory, + **OrthographicWordListGenerator.parse_parameters( + args.config_path, args, unknown_args + ) + ) + else: - with open(args.output_path, "w", encoding="utf8") as f: - for word in word_set: - pronunciation = list(word) - f.write(f"{word} {' '.join(pronunciation)}\n") + if os.path.isdir(args.input_path): + g2p = PyniniCorpusGenerator( + g2p_model_path=args.g2p_model_path, + corpus_directory=args.input_path, + 
temporary_directory=args.temporary_directory, + **PyniniCorpusGenerator.parse_parameters(args.config_path, args, unknown_args) + ) + else: + g2p = PyniniWordListGenerator( + g2p_model_path=args.g2p_model_path, + word_list_path=args.input_path, + temporary_directory=args.temporary_directory, + **PyniniWordListGenerator.parse_parameters(args.config_path, args, unknown_args) + ) + + try: + g2p.setup() + g2p.export_pronunciations(args.output_path) + except Exception: + g2p.dirty = True + raise + finally: + g2p.cleanup() def validate_args(args: Namespace) -> None: @@ -134,7 +95,7 @@ def validate_args(args: Namespace) -> None: args.g2p_model_path = validate_model_arg(args.g2p_model_path, "g2p") -def run_g2p(args: Namespace, unknown: Optional[list] = None) -> None: +def run_g2p(args: Namespace, unknown: Optional[list[str]] = None) -> None: """ Wrapper function for running G2P @@ -142,7 +103,7 @@ def run_g2p(args: Namespace, unknown: Optional[list] = None) -> None: ---------- args: :class:`~argparse.Namespace` Parsed command line arguments - unknown: List[str] + unknown: list[str] Parsed command line arguments to be passed to the configuration objects """ validate_args(args) diff --git a/montreal_forced_aligner/command_line/mfa.py b/montreal_forced_aligner/command_line/mfa.py index 7b3fb96a..704f3469 100644 --- a/montreal_forced_aligner/command_line/mfa.py +++ b/montreal_forced_aligner/command_line/mfa.py @@ -33,13 +33,6 @@ ) from montreal_forced_aligner.exceptions import MFAError from montreal_forced_aligner.models import MODEL_TYPES -from montreal_forced_aligner.utils import ( - get_available_acoustic_models, - get_available_dictionaries, - get_available_g2p_models, - get_available_ivector_extractors, - get_available_language_models, -) if TYPE_CHECKING: from argparse import ArgumentParser @@ -49,7 +42,7 @@ BEGIN_DATE = datetime.now() -__all__ = ["ExitHooks", "history_save_handler", "create_parser", "main"] +__all__ = ["ExitHooks", "create_parser", "main"] class 
ExitHooks(object): @@ -75,41 +68,34 @@ def exit(self, code=0): def exc_handler(self, exc_type, exc, *args): """Handle and save exceptions""" self.exception = exc + self.exit_code = 1 - -def history_save_handler() -> None: - """ - Handler for saving history on exit. In addition to the command run, also saves exit code, whether - an exception was encountered, when the command was executed, and how long it took to run - """ - from montreal_forced_aligner.utils import get_mfa_version - - history_data = { - "command": " ".join(sys.argv), - "execution_time": time.time() - BEGIN, - "date": BEGIN_DATE, - "version": get_mfa_version(), - } - - if hooks.exit_code is not None: - history_data["exit_code"] = hooks.exit_code - history_data["exception"] = "" - elif hooks.exception is not None: - history_data["exit_code"] = 1 - history_data["exception"] = str(hooks.exception) - else: - history_data["exception"] = "" - history_data["exit_code"] = 0 - update_command_history(history_data) - if hooks.exception: - raise hooks.exception - - -acoustic_models = get_available_acoustic_models() -ivector_extractors = get_available_ivector_extractors() -language_models = get_available_language_models() -g2p_models = get_available_g2p_models() -dictionaries = get_available_dictionaries() + def history_save_handler(self) -> None: + """ + Handler for saving history on exit. 
In addition to the command run, also saves exit code, whether + an exception was encountered, when the command was executed, and how long it took to run + """ + from montreal_forced_aligner.utils import get_mfa_version + + history_data = { + "command": " ".join(sys.argv), + "execution_time": time.time() - BEGIN, + "date": BEGIN_DATE, + "version": get_mfa_version(), + } + + if self.exit_code is not None: + history_data["exit_code"] = self.exit_code + history_data["exception"] = "" + elif self.exception is not None: + history_data["exit_code"] = 1 + history_data["exception"] = str(self.exception) + else: + history_data["exception"] = "" + history_data["exit_code"] = 0 + update_command_history(history_data) + if self.exception: + raise self.exception def create_parser() -> ArgumentParser: @@ -137,9 +123,11 @@ def add_global_options(subparser: argparse.ArgumentParser, textgrid_output: bool subparser.add_argument( "-t", "--temp_directory", + "--temporary_directory", + dest="temporary_directory", type=str, - default=GLOBAL_CONFIG["temp_directory"], - help=f"Temporary directory root to store MFA created files, default is {GLOBAL_CONFIG['temp_directory']}", + default=GLOBAL_CONFIG["temporary_directory"], + help=f"Temporary directory root to store MFA created files, default is {GLOBAL_CONFIG['temporary_directory']}", ) subparser.add_argument( "--disable_mp", @@ -188,6 +176,48 @@ def add_global_options(subparser: argparse.ArgumentParser, textgrid_output: bool default=not GLOBAL_CONFIG["cleanup_textgrids"], ) + pretrained_acoustic = ", ".join(MODEL_TYPES["acoustic"].get_available_models()) + if not pretrained_acoustic: + pretrained_acoustic = ( + "you can use ``mfa model download acoustic`` to get pretrained MFA models" + ) + + pretrained_ivector = ", ".join(MODEL_TYPES["ivector"].get_available_models()) + if not pretrained_ivector: + pretrained_ivector = ( + "you can use ``mfa model download ivector`` to get pretrained MFA models" + ) + + pretrained_g2p = ", 
".join(MODEL_TYPES["g2p"].get_available_models()) + if not pretrained_g2p: + pretrained_g2p = "you can use ``mfa model download g2p`` to get pretrained MFA models" + + pretrained_lm = ", ".join(MODEL_TYPES["language_model"].get_available_models()) + if not pretrained_lm: + pretrained_lm = ( + "you can use ``mfa model download language_model`` to get pretrained MFA models" + ) + + pretrained_dictionary = ", ".join(MODEL_TYPES["dictionary"].get_available_models()) + if not pretrained_dictionary: + pretrained_dictionary = ( + "you can use ``mfa model download dictionary`` to get MFA dictionaries" + ) + + dictionary_path_help = f"Full path to pronunciation dictionary, or saved dictionary name ({pretrained_dictionary})" + + acoustic_model_path_help = ( + f"Full path to pre-trained acoustic model, or saved model name ({pretrained_acoustic})" + ) + language_model_path_help = ( + f"Full path to pre-trained language model, or saved model name ({pretrained_lm})" + ) + ivector_model_path_help = f"Full path to pre-trained ivector extractor model, or saved model name ({pretrained_ivector})" + g2p_model_path_help = ( + f"Full path to pre-trained G2P model, or saved model name ({pretrained_g2p}). " + "If not specified, then orthographic transcription is split into pronunciations." 
+ ) + parser = argparse.ArgumentParser() subparsers = parser.add_subparsers(dest="subcommand") @@ -200,14 +230,18 @@ def add_global_options(subparser: argparse.ArgumentParser, textgrid_output: bool ) align_parser.add_argument("corpus_directory", help="Full path to the directory to align") align_parser.add_argument( - "dictionary_path", help="Full path to the pronunciation dictionary to use" + "dictionary_path", + help=dictionary_path_help, + type=str, ) align_parser.add_argument( "acoustic_model_path", - help=f"Full path to the archive containing pre-trained model or language ({', '.join(acoustic_models)})", + type=str, + help=acoustic_model_path_help, ) align_parser.add_argument( "output_directory", + type=str, help="Full path to output directory, will be created if it doesn't exist", ) align_parser.add_argument( @@ -232,15 +266,15 @@ def add_global_options(subparser: argparse.ArgumentParser, textgrid_output: bool adapt_parser = subparsers.add_parser("adapt", help="Adapt an acoustic model to a new corpus") adapt_parser.add_argument("corpus_directory", help="Full path to the directory to align") - adapt_parser.add_argument( - "dictionary_path", help="Full path to the pronunciation dictionary to use" - ) + adapt_parser.add_argument("dictionary_path", type=str, help=dictionary_path_help) adapt_parser.add_argument( "acoustic_model_path", - help=f"Full path to the archive containing pre-trained model or language ({', '.join(acoustic_models)})", + type=str, + help=acoustic_model_path_help, ) adapt_parser.add_argument( "output_paths", + type=str, nargs="+", help="Path to save the new acoustic model, path to export aligned TextGrids, or both", ) @@ -251,12 +285,6 @@ def add_global_options(subparser: argparse.ArgumentParser, textgrid_output: bool default="", help="Full path to save adapted acoustic model", ) - adapt_parser.add_argument( - "--full_train", - action="store_true", - help="Specify whether to do a round of speaker-adapted training rather than the default " - 
"remapping approach to adaptation", - ) adapt_parser.add_argument( "--config_path", type=str, default="", help="Path to config file to use for alignment" ) @@ -281,13 +309,12 @@ def add_global_options(subparser: argparse.ArgumentParser, textgrid_output: bool "train", help="Train a new acoustic model on a corpus and optionally export alignments" ) train_parser.add_argument( - "corpus_directory", help="Full path to the source directory to align" - ) - train_parser.add_argument( - "dictionary_path", help="Full path to the pronunciation dictionary to use", default="" + "corpus_directory", type=str, help="Full path to the source directory to align" ) + train_parser.add_argument("dictionary_path", type=str, help=dictionary_path_help, default="") train_parser.add_argument( "output_paths", + type=str, nargs="+", help="Path to save the new acoustic model, path to export aligned TextGrids, or both", ) @@ -323,16 +350,17 @@ def add_global_options(subparser: argparse.ArgumentParser, textgrid_output: bool validate_parser = subparsers.add_parser("validate", help="Validate a corpus for use in MFA") validate_parser.add_argument( - "corpus_directory", help="Full path to the source directory to align" + "corpus_directory", type=str, help="Full path to the source directory to align" ) validate_parser.add_argument( - "dictionary_path", help="Full path to the pronunciation dictionary to use", default="" + "dictionary_path", type=str, help=dictionary_path_help, default="" ) validate_parser.add_argument( "acoustic_model_path", + type=str, nargs="?", default="", - help=f"Full path to the archive containing pre-trained model or language ({', '.join(acoustic_models)})", + help=acoustic_model_path_help, ) validate_parser.add_argument( "-s", @@ -342,6 +370,12 @@ def add_global_options(subparser: argparse.ArgumentParser, textgrid_output: bool help="Number of characters of file names to use for determining speaker, " "default is to use directory names", ) + validate_parser.add_argument( + 
"--config_path", + type=str, + default="", + help="Path to config file to use for training and alignment", + ) validate_parser.add_argument( "--test_transcriptions", help="Test accuracy of transcriptions", action="store_true" ) @@ -350,6 +384,13 @@ def add_global_options(subparser: argparse.ArgumentParser, textgrid_output: bool help="Skip acoustic feature generation and associated validation", action="store_true", ) + validate_parser.add_argument( + "-a", + "--audio_directory", + type=str, + default="", + help="Audio directory root to use for finding audio files", + ) add_global_options(validate_parser) g2p_parser = subparsers.add_parser( @@ -357,15 +398,17 @@ def add_global_options(subparser: argparse.ArgumentParser, textgrid_output: bool ) g2p_parser.add_argument( "g2p_model_path", - help=f"Full path to the archive containing pre-trained model or language ({', '.join(g2p_models)}). If not specified, then orthographic transcription is split into pronunciations.", + help=g2p_model_path_help, + type=str, nargs="?", ) g2p_parser.add_argument( "input_path", + type=str, help="Corpus to base word list on or a text file of words to generate pronunciations", ) - g2p_parser.add_argument("output_path", help="Path to save output dictionary") + g2p_parser.add_argument("output_path", type=str, help="Path to save output dictionary") g2p_parser.add_argument( "--include_bracketed", help="Included words enclosed by brackets, job_name.e. 
[...], (...), <...>", @@ -379,14 +422,18 @@ def add_global_options(subparser: argparse.ArgumentParser, textgrid_output: bool train_g2p_parser = subparsers.add_parser( "train_g2p", help="Train a G2P model from a pronunciation dictionary" ) - train_g2p_parser.add_argument("dictionary_path", help="Location of existing dictionary") + train_g2p_parser.add_argument("dictionary_path", type=str, help=dictionary_path_help) - train_g2p_parser.add_argument("output_model_path", help="Desired location of generated model") + train_g2p_parser.add_argument( + "output_model_path", type=str, help="Desired location of generated model" + ) train_g2p_parser.add_argument( "--config_path", type=str, default="", help="Path to config file to use for G2P" ) train_g2p_parser.add_argument( + "--evaluate", "--validate", + dest="evaluate", action="store_true", help="Perform an analysis of accuracy training on " "most of the data and validating on an unseen subset", @@ -410,6 +457,7 @@ def add_global_options(subparser: argparse.ArgumentParser, textgrid_output: bool "name", help="Name of language code to download, if not specified, " "will list all available languages", + type=str, nargs="?", ) help_message = "List of saved models" @@ -417,7 +465,11 @@ def add_global_options(subparser: argparse.ArgumentParser, textgrid_output: bool "list", description=help_message, help=help_message ) model_list_parser.add_argument( - "model_type", choices=sorted(MODEL_TYPES), nargs="?", help="Type of model to list" + "model_type", + choices=sorted(MODEL_TYPES), + type=str, + nargs="?", + help="Type of model to list", ) help_message = "Inspect a model and output its metadata" @@ -427,11 +479,12 @@ def add_global_options(subparser: argparse.ArgumentParser, textgrid_output: bool model_inspect_parser.add_argument( "model_type", choices=sorted(MODEL_TYPES), + type=str, nargs="?", help="Type of model to download", ) model_inspect_parser.add_argument( - "name", help="Name of pretrained model or path to MFA model to 
inspect" + "name", type=str, help="Name of pretrained model or path to MFA model to inspect" ) help_message = "Save a MFA model to the pretrained directory for name-based referencing" @@ -439,7 +492,7 @@ def add_global_options(subparser: argparse.ArgumentParser, textgrid_output: bool "save", description=help_message, help=help_message ) model_save_parser.add_argument( - "model_type", choices=sorted(MODEL_TYPES), help="Type of MFA model" + "model_type", type=str, choices=sorted(MODEL_TYPES), help="Type of MFA model" ) model_save_parser.add_argument( "path", help="Path to MFA model to save for invoking with just its name" @@ -461,6 +514,7 @@ def add_global_options(subparser: argparse.ArgumentParser, textgrid_output: bool ) train_lm_parser.add_argument( "source_path", + type=str, help="Full path to the source directory to train from, alternatively " "an ARPA format language model to convert for MFA use", ) @@ -481,7 +535,7 @@ def add_global_options(subparser: argparse.ArgumentParser, textgrid_output: bool help="Weight factor for supplemental language model, defaults to 1.0", ) train_lm_parser.add_argument( - "--dictionary_path", help="Full path to the pronunciation dictionary to use", default="" + "--dictionary_path", type=str, help=dictionary_path_help, default="" ) train_lm_parser.add_argument( "--config_path", @@ -498,15 +552,15 @@ def add_global_options(subparser: argparse.ArgumentParser, textgrid_output: bool train_dictionary_parser.add_argument( "corpus_directory", help="Full path to the directory to align" ) - train_dictionary_parser.add_argument( - "dictionary_path", help="Full path to the pronunciation dictionary to use" - ) + train_dictionary_parser.add_argument("dictionary_path", type=str, help=dictionary_path_help) train_dictionary_parser.add_argument( "acoustic_model_path", - help=f"Full path to the archive containing pre-trained model or language ({', '.join(acoustic_models)})", + type=str, + help=acoustic_model_path_help, ) 
train_dictionary_parser.add_argument( "output_directory", + type=str, help="Full path to output directory, will be created if it doesn't exist", ) train_dictionary_parser.add_argument( @@ -528,21 +582,12 @@ def add_global_options(subparser: argparse.ArgumentParser, textgrid_output: bool ) train_ivector_parser.add_argument( "corpus_directory", - help="Full path to the source directory to train the ivector extractor", - ) - train_ivector_parser.add_argument( - "dictionary_path", help="Full path to the pronunciation dictionary to use" - ) - train_ivector_parser.add_argument( - "acoustic_model_path", type=str, - default="", - help="Full path to acoustic model for alignment", + help="Full path to the source directory to train the ivector extractor", ) train_ivector_parser.add_argument( "output_model_path", type=str, - default="", help="Full path to save resulting ivector extractor", ) train_ivector_parser.add_argument( @@ -563,13 +608,15 @@ def add_global_options(subparser: argparse.ArgumentParser, textgrid_output: bool ) classify_speakers_parser.add_argument( "corpus_directory", + type=str, help="Full path to the source directory to run speaker classification", ) classify_speakers_parser.add_argument( - "ivector_extractor_path", type=str, default="", help="Full path to ivector extractor model" + "ivector_extractor_path", type=str, default="", help=ivector_model_path_help ) classify_speakers_parser.add_argument( "output_directory", + type=str, help="Full path to output directory, will be created if it doesn't exist", ) @@ -595,6 +642,7 @@ def add_global_options(subparser: argparse.ArgumentParser, textgrid_output: bool ) create_segments_parser.add_argument( "output_directory", + type=str, help="Full path to output directory, will be created if it doesn't exist", ) create_segments_parser.add_argument( @@ -607,21 +655,22 @@ def add_global_options(subparser: argparse.ArgumentParser, textgrid_output: bool help="Transcribe utterances using an acoustic model, language model, 
and pronunciation dictionary", ) transcribe_parser.add_argument( - "corpus_directory", help="Full path to the directory to transcribe" - ) - transcribe_parser.add_argument( - "dictionary_path", help="Full path to the pronunciation dictionary to use" + "corpus_directory", type=str, help="Full path to the directory to transcribe" ) + transcribe_parser.add_argument("dictionary_path", type=str, help=dictionary_path_help) transcribe_parser.add_argument( "acoustic_model_path", - help=f"Full path to the archive containing pre-trained model or language ({', '.join(acoustic_models)})", + type=str, + help=acoustic_model_path_help, ) transcribe_parser.add_argument( "language_model_path", - help=f"Full path to the archive containing pre-trained model or language ({', '.join(language_models)})", + type=str, + help=language_model_path_help, ) transcribe_parser.add_argument( "output_directory", + type=str, help="Full path to output directory, will be created if it doesn't exist", ) transcribe_parser.add_argument( @@ -658,9 +707,11 @@ def add_global_options(subparser: argparse.ArgumentParser, textgrid_output: bool config_parser.add_argument( "-t", "--temp_directory", + "--temporary_directory", + dest="temporary_directory", type=str, default="", - help=f"Set the default temporary directory, default is {GLOBAL_CONFIG['temp_directory']}", + help=f"Set the default temporary directory, default is {GLOBAL_CONFIG['temporary_directory']}", ) config_parser.add_argument( "-j", @@ -750,9 +801,14 @@ def add_global_options(subparser: argparse.ArgumentParser, textgrid_output: bool "download", help="DEPRECATED: Please use mfa model download instead." 
) - history_parser.add_argument("depth", help="Number of commands to list", nargs="?", default=10) history_parser.add_argument( - "--verbose", help="Flag for whether to output additional information", action="store_true" + "depth", type=int, help="Number of commands to list", nargs="?", default=10 + ) + history_parser.add_argument( + "-v", + "--verbose", + help=f"Output debug messages, default is {GLOBAL_CONFIG['verbose']}", + action="store_true", ) _ = subparsers.add_parser( @@ -765,10 +821,34 @@ def add_global_options(subparser: argparse.ArgumentParser, textgrid_output: bool parser = create_parser() +def print_history(args): + depth = args.depth + history = load_command_history()[-depth:] + if args.verbose: + print("command\tDate\tExecution time\tVersion\tExit code\tException") + for h in history: + execution_time = time.strftime("%H:%M:%S", time.gmtime(h["execution_time"])) + d = h["date"].isoformat() + print( + f"{h['command']}\t{d}\t{execution_time}\t{h['version']}\t{h['exit_code']}\t{h['exception']}" + ) + pass + else: + for h in history: + print(h["command"]) + + def main() -> None: """ Main function for the MFA command line interface """ + + hooks = ExitHooks() + hooks.hook() + atexit.register(hooks.history_save_handler) + from colorama import init + + init() parser = create_parser() mp.freeze_support() args, unknown = parser.parse_known_args() @@ -776,7 +856,8 @@ def main() -> None: if short in unknown: print( f"Due to the number of options that `{short}` could refer to, it is not accepted. " - "Please specify the full argument" + "Please specify the full argument", + file=sys.stderr, ) sys.exit(1) try: @@ -786,7 +867,8 @@ def main() -> None: except ImportError: print( "There was an issue importing Pynini, please ensure that it is installed. If you are on Windows, " - "please use the Windows Subsystem for Linux to use g2p functionality." 
+ "please use the Windows Subsystem for Linux to use g2p functionality.", + file=sys.stderr, ) sys.exit(1) if args.subcommand == "align": @@ -822,21 +904,7 @@ def main() -> None: global GLOBAL_CONFIG GLOBAL_CONFIG = load_global_config() elif args.subcommand == "history": - depth = args.depth - history = load_command_history()[-depth:] - if args.verbose: - print("command\tDate\tExecution time\tVersion\tExit code\tException") - for h in history: - execution_time = time.strftime("%H:%M:%S", time.gmtime(h["execution_time"])) - d = h["date"].isoformat() - print( - f"{h['command']}\t{d}\t{execution_time}\t{h['version']}\t{h['exit_code']}\t{h['exception']}" - ) - pass - else: - for h in history: - print(h["command"]) - + print_history(args) elif args.subcommand == "version": from montreal_forced_aligner.utils import get_mfa_version @@ -852,15 +920,15 @@ def main() -> None: except MFAError as e: if getattr(args, "debug", False): raise - print(e) + print(e, file=sys.stderr) sys.exit(1) if __name__ == "__main__": - hooks = ExitHooks() - hooks.hook() - atexit.register(history_save_handler) - from colorama import init + import warnings - init() + warnings.warn( + "Use 'python -m montreal_forced_aligner', not 'python -m montreal_forced_aligner.command_line.mfa'", + DeprecationWarning, + ) main() diff --git a/montreal_forced_aligner/command_line/model.py b/montreal_forced_aligner/command_line/model.py index 2d21c06d..d3496631 100644 --- a/montreal_forced_aligner/command_line/model.py +++ b/montreal_forced_aligner/command_line/model.py @@ -3,11 +3,11 @@ import os import shutil -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, Optional, Union import requests -from montreal_forced_aligner.config import TEMP_DIR +from montreal_forced_aligner.config import get_temporary_directory from montreal_forced_aligner.exceptions import ( FileArgumentNotFoundError, ModelLoadError, @@ -17,11 +17,7 @@ ) from montreal_forced_aligner.helper import 
TerminalPrinter from montreal_forced_aligner.models import MODEL_TYPES, Archive -from montreal_forced_aligner.utils import ( - get_available_models, - get_pretrained_path, - guess_model_type, -) +from montreal_forced_aligner.utils import guess_model_type if TYPE_CHECKING: from argparse import Namespace @@ -38,7 +34,7 @@ ] -def list_downloadable_models(model_type: str) -> List[str]: +def list_downloadable_models(model_type: str) -> list[str]: """ Generate a list of models available for download @@ -49,13 +45,13 @@ def list_downloadable_models(model_type: str) -> List[str]: Returns ------- - List[str] + list[str] Names of models """ url = f"https://raw.githubusercontent.com/MontrealCorpusTools/mfa-models/main/{model_type}/index.txt" r = requests.get(url) if r.status_code == 404: - raise Exception('Could not find model type "{}"'.format(model_type)) + raise Exception(f'Could not find model type "{model_type}"') out = r.text return out.split("\n") @@ -75,9 +71,10 @@ def download_model(model_type: str, name: str) -> None: downloadable = "\n".join(f" - {x}" for x in list_downloadable_models(model_type)) print(f"Available models to download for {model_type}:\n\n{downloadable}") try: - mc = MODEL_TYPES[model_type] - extension = mc.extensions[0] - out_path = get_pretrained_path(model_type, name, enforce_existence=False) + model_class = MODEL_TYPES[model_type] + extension = model_class.extensions[0] + os.makedirs(model_class.pretrained_directory(), exist_ok=True) + out_path = model_class.get_pretrained_path(name, enforce_existence=False) except KeyError: raise NotImplementedError( f"{model_type} models are not currently supported for downloading" @@ -101,15 +98,16 @@ def list_model(model_type: Union[str, None]) -> None: printer = TerminalPrinter() if model_type is None: printer.print_information_line("Available models for use", "", level=0) - for mt in MODEL_TYPES: - names = get_available_models(mt) + for model_type, model_class in MODEL_TYPES.items(): + names = 
model_class.get_available_models() if names: - printer.print_information_line(mt, names, value_color="green") + printer.print_information_line(model_type, names, value_color="green") else: - printer.print_information_line(mt, "No models found", value_color="yellow") + printer.print_information_line(model_type, "No models found", value_color="yellow") else: printer.print_information_line(f"Available models for use {model_type}", "", level=0) - names = get_available_models(model_type) + model_class = MODEL_TYPES[model_type] + names = model_class.get_available_models() if names: for name in names: @@ -127,7 +125,7 @@ def inspect_model(path: str) -> None: path: str Path to model """ - working_dir = os.path.join(TEMP_DIR, "models", "inspect") + working_dir = os.path.join(get_temporary_directory(), "models", "inspect") ext = os.path.splitext(path)[1] model = None if ext == Archive.extensions[0]: # Figure out what kind of model it is @@ -154,10 +152,11 @@ def save_model(path: str, model_type: str, output_name: Optional[str]) -> None: Type of model """ model_name = os.path.splitext(os.path.basename(path))[0] + model_class = MODEL_TYPES[model_type] if output_name: - out_path = get_pretrained_path(model_type, output_name, enforce_existence=False) + out_path = model_class.get_pretrained_path(output_name, enforce_existence=False) else: - out_path = get_pretrained_path(model_type, model_name, enforce_existence=False) + out_path = model_class.get_pretrained_path(model_name, enforce_existence=False) shutil.copyfile(path, out_path) @@ -207,16 +206,17 @@ def validate_args(args: Namespace) -> None: possible_model_types = guess_model_type(args.name) if not possible_model_types: if args.model_type: - path = get_pretrained_path(args.model_type, args.name) + model_class = MODEL_TYPES[args.model_type] + path = model_class.get_pretrained_path(args.name) if path is None: raise PretrainedModelNotFoundError( - args.name, args.model_type, get_available_models(args.model_type) + args.name, 
args.model_type, model_class.get_available_models() ) else: found_model_types = [] path = None - for model_type in MODEL_TYPES: - p = get_pretrained_path(model_type, args.name) + for model_type, model_class in MODEL_TYPES.items(): + p = model_class.get_pretrained_path(args.name) if p is not None: path = p found_model_types.append(model_type) @@ -241,8 +241,6 @@ def run_model(args: Namespace) -> None: ---------- args: :class:`~argparse.Namespace` Parsed command line arguments - unknown: List[str] - Parsed command line arguments to be passed to the configuration objects """ validate_args(args) if args.action == "download": diff --git a/montreal_forced_aligner/command_line/train_acoustic_model.py b/montreal_forced_aligner/command_line/train_acoustic_model.py index a02f7735..be337eb1 100644 --- a/montreal_forced_aligner/command_line/train_acoustic_model.py +++ b/montreal_forced_aligner/command_line/train_acoustic_model.py @@ -2,22 +2,11 @@ from __future__ import annotations import os -import shutil -import time from typing import TYPE_CHECKING, Optional -from montreal_forced_aligner.aligner import TrainableAligner +from montreal_forced_aligner.acoustic_modeling import TrainableAligner from montreal_forced_aligner.command_line.utils import validate_model_arg -from montreal_forced_aligner.config import ( - TEMP_DIR, - load_basic_train, - load_command_configuration, - train_yaml_to_config, -) -from montreal_forced_aligner.corpus import Corpus -from montreal_forced_aligner.dictionary import MultispeakerDictionary from montreal_forced_aligner.exceptions import ArgumentError -from montreal_forced_aligner.utils import log_config, setup_logger if TYPE_CHECKING: from argparse import Namespace @@ -34,160 +23,35 @@ def train_acoustic_model(args: Namespace, unknown_args: Optional[list] = None) - ---------- args: :class:`~argparse.Namespace` Command line arguments - unknown_args: List[str] + unknown_args: list[str] Optional arguments that will be passed to configuration objects """ 
- from montreal_forced_aligner.utils import get_mfa_version - - command = "train_acoustic_model" - all_begin = time.time() - if not args.temp_directory: - temp_dir = TEMP_DIR - else: - temp_dir = os.path.expanduser(args.temp_directory) - corpus_name = os.path.basename(args.corpus_directory) - if corpus_name == "": - args.corpus_directory = os.path.dirname(args.corpus_directory) - corpus_name = os.path.basename(args.corpus_directory) - data_directory = os.path.join(temp_dir, corpus_name) - if args.config_path: - train_config, align_config, dictionary_config = train_yaml_to_config(args.config_path) - else: - train_config, align_config, dictionary_config = load_basic_train() - train_config.use_mp = not args.disable_mp - align_config.use_mp = not args.disable_mp - align_config.debug = args.debug - align_config.overwrite = args.overwrite - align_config.cleanup_textgrids = not args.disable_textgrid_cleanup - if unknown_args: - train_config.update_from_unknown_args(unknown_args) - align_config.update_from_unknown_args(unknown_args) - train_config.update_from_align(align_config) - conf_path = os.path.join(data_directory, "config.yml") - if getattr(args, "clean", False) and os.path.exists(data_directory): - print("Cleaning old directory!") - shutil.rmtree(data_directory, ignore_errors=True) - if getattr(args, "verbose", False): - log_level = "debug" - else: - log_level = "info" - logger = setup_logger(command, data_directory, console_level=log_level) - logger.debug("TRAIN CONFIG:") - log_config(logger, train_config) - logger.debug("ALIGN CONFIG:") - log_config(logger, align_config) - if args.debug: - logger.warning("Running in DEBUG mode, may have impact on performance and disk usage.") - conf = load_command_configuration( - conf_path, - { - "dirty": False, - "begin": time.time(), - "version": get_mfa_version(), - "type": command, - "corpus_directory": args.corpus_directory, - "dictionary_path": args.dictionary_path, - }, + + trainer = TrainableAligner( + 
corpus_directory=args.corpus_directory, + dictionary_path=args.dictionary_path, + temporary_directory=args.temporary_directory, + **TrainableAligner.parse_parameters(args.config_path, args, unknown_args), ) - if ( - conf["dirty"] - or conf["type"] != command - or conf["corpus_directory"] != args.corpus_directory - or conf["version"] != get_mfa_version() - or conf["dictionary_path"] != args.dictionary_path - ): - logger.warning( - "WARNING: Using old temp directory, this might not be ideal for you, use the --clean flag to ensure no " - "weird behavior for previous versions of the temporary directory." - ) - if conf["dirty"]: - logger.debug("Previous run ended in an error (maybe ctrl-c?)") - if conf["type"] != command: - logger.debug( - f"Previous run was a different subcommand than {command} (was {conf['type']})" - ) - if conf["corpus_directory"] != args.corpus_directory: - logger.debug( - "Previous run used source directory " - f"path {conf['corpus_directory']} (new run: {args.corpus_directory})" - ) - if conf["version"] != get_mfa_version(): - logger.debug( - f"Previous run was on {conf['version']} version (new run: {get_mfa_version()})" - ) - if conf["dictionary_path"] != args.dictionary_path: - logger.debug( - f"Previous run used dictionary path {conf['dictionary_path']} " - f"(new run: {args.dictionary_path})" - ) - os.makedirs(data_directory, exist_ok=True) - model_directory = os.path.join(data_directory, "acoustic_models") - audio_dir = None - if args.audio_directory: - audio_dir = args.audio_directory try: - corpus = Corpus( - args.corpus_directory, - data_directory, - dictionary_config, - speaker_characters=args.speaker_characters, - num_jobs=getattr(args, "num_jobs", 3), - sample_rate=align_config.feature_config.sample_frequency, - debug=getattr(args, "debug", False), - logger=logger, - use_mp=align_config.use_mp, - audio_directory=audio_dir, - ) - logger.info(corpus.speaker_utterance_info()) - dictionary = MultispeakerDictionary( - args.dictionary_path, - 
data_directory, - dictionary_config, - word_set=corpus.word_set, - logger=logger, - ) - utt_oov_path = os.path.join(corpus.split_directory, "utterance_oovs.txt") - if os.path.exists(utt_oov_path): - shutil.copy(utt_oov_path, args.output_directory) - oov_path = os.path.join(corpus.split_directory, "oovs_found.txt") - if os.path.exists(oov_path): - shutil.copy(oov_path, args.output_directory) - a = TrainableAligner( - corpus, - dictionary, - train_config, - align_config, - temp_directory=data_directory, - logger=logger, - debug=getattr(args, "debug", False), - ) - a.verbose = args.verbose - begin = time.time() generate_final_alignments = True if args.output_directory is None: generate_final_alignments = False else: os.makedirs(args.output_directory, exist_ok=True) - a.train(generate_final_alignments) - logger.debug(f"Training took {time.time() - begin} seconds") + trainer.train(generate_final_alignments) if args.output_model_path is not None: - a.save(args.output_model_path, root_directory=model_directory) + trainer.export_model(args.output_model_path) if args.output_directory is not None: - a.export_textgrids(args.output_directory) - logger.info("All done!") - logger.debug(f"Done! 
Everything took {time.time() - all_begin} seconds") + trainer.export_files(args.output_directory) except Exception: - conf["dirty"] = True + trainer.dirty = True raise finally: - handlers = logger.handlers[:] - for handler in handlers: - handler.close() - logger.removeHandler(handler) - conf.save(conf_path) + trainer.cleanup() def validate_args(args: Namespace) -> None: @@ -244,7 +108,7 @@ def run_train_acoustic_model(args: Namespace, unknown_args: Optional[list] = Non ---------- args: :class:`~argparse.Namespace` Parsed command line arguments - unknown: List[str] + unknown: list[str] Parsed command line arguments to be passed to the configuration objects """ validate_args(args) diff --git a/montreal_forced_aligner/command_line/train_dictionary.py b/montreal_forced_aligner/command_line/train_dictionary.py index af437b59..0d3ad356 100644 --- a/montreal_forced_aligner/command_line/train_dictionary.py +++ b/montreal_forced_aligner/command_line/train_dictionary.py @@ -2,23 +2,11 @@ from __future__ import annotations import os -import shutil -import time from typing import TYPE_CHECKING, Optional -from montreal_forced_aligner.aligner import PretrainedAligner +from montreal_forced_aligner.alignment.pretrained import DictionaryTrainer from montreal_forced_aligner.command_line.utils import validate_model_arg -from montreal_forced_aligner.config import ( - TEMP_DIR, - align_yaml_to_config, - load_basic_align, - load_command_configuration, -) -from montreal_forced_aligner.corpus import Corpus -from montreal_forced_aligner.dictionary import MultispeakerDictionary from montreal_forced_aligner.exceptions import ArgumentError -from montreal_forced_aligner.models import AcousticModel -from montreal_forced_aligner.utils import log_config, setup_logger if TYPE_CHECKING: from argparse import Namespace @@ -35,144 +23,25 @@ def train_dictionary(args: Namespace, unknown_args: Optional[list] = None) -> No ---------- args: :class:`~argparse.Namespace` Command line arguments - 
unknown_args: List[str] + unknown_args: list[str] Optional arguments that will be passed to configuration objects """ - from montreal_forced_aligner.utils import get_mfa_version - - command = "train_dictionary" - all_begin = time.time() - if not args.temp_directory: - temp_dir = TEMP_DIR - else: - temp_dir = os.path.expanduser(args.temp_directory) - corpus_name = os.path.basename(args.corpus_directory) - if corpus_name == "": - args.corpus_directory = os.path.dirname(args.corpus_directory) - corpus_name = os.path.basename(args.corpus_directory) - data_directory = os.path.join(temp_dir, corpus_name) - conf_path = os.path.join(data_directory, "config.yml") - if args.config_path: - align_config, dictionary_config = align_yaml_to_config(args.config_path) - else: - align_config, dictionary_config = load_basic_align() - align_config.use_mp = not args.disable_mp - align_config.overwrite = args.overwrite - align_config.debug = args.debug - dictionary_config.debug = args.debug - if unknown_args: - align_config.update_from_unknown_args(unknown_args) - dictionary_config.update_from_unknown_args(unknown_args) - if getattr(args, "clean", False) and os.path.exists(data_directory): - print("Cleaning old directory!") - shutil.rmtree(data_directory, ignore_errors=True) - if getattr(args, "verbose", False): - log_level = "debug" - else: - log_level = "info" - logger = setup_logger(command, data_directory, console_level=log_level) - logger.debug("ALIGN CONFIG:") - log_config(logger, align_config) - conf = load_command_configuration( - conf_path, - { - "dirty": False, - "begin": time.time(), - "version": get_mfa_version(), - "type": command, - "corpus_directory": args.corpus_directory, - "dictionary_path": args.dictionary_path, - "acoustic_model_path": args.acoustic_model_path, - }, + aligner = DictionaryTrainer( + acoustic_model_path=args.acoustic_model_path, + corpus_directory=args.corpus_directory, + dictionary_path=args.dictionary_path, + 
temporary_directory=args.temporary_directory, + **DictionaryTrainer.parse_parameters(args.config_path, args, unknown_args), ) - if ( - conf["dirty"] - or conf["type"] != command - or conf["corpus_directory"] != args.corpus_directory - or conf["version"] != get_mfa_version() - or conf["dictionary_path"] != args.dictionary_path - ): - logger.warning( - "WARNING: Using old temp directory, this might not be ideal for you, use the --clean flag to ensure no " - "weird behavior for previous versions of the temporary directory." - ) - if conf["dirty"]: - logger.debug("Previous run ended in an error (maybe ctrl-c?)") - if conf["type"] != command: - logger.debug( - f"Previous run was a different subcommand than {command} (was {conf['type']})" - ) - if conf["corpus_directory"] != args.corpus_directory: - logger.debug( - "Previous run used source directory " - f"path {conf['corpus_directory']} (new run: {args.corpus_directory})" - ) - if conf["version"] != get_mfa_version(): - logger.debug( - f"Previous run was on {conf['version']} version (new run: {get_mfa_version()})" - ) - if conf["dictionary_path"] != args.dictionary_path: - logger.debug( - f"Previous run used dictionary path {conf['dictionary_path']} " - f"(new run: {args.dictionary_path})" - ) - if conf["acoustic_model_path"] != args.acoustic_model_path: - logger.debug( - f"Previous run used acoustic model path {conf['acoustic_model_path']} " - f"(new run: {args.acoustic_model_path})" - ) - - os.makedirs(data_directory, exist_ok=True) - try: - corpus = Corpus( - args.corpus_directory, - data_directory, - dictionary_config, - speaker_characters=args.speaker_characters, - num_jobs=args.num_jobs, - sample_rate=align_config.feature_config.sample_frequency, - use_mp=align_config.use_mp, - logger=logger, - ) - logger.info(corpus.speaker_utterance_info()) - acoustic_model = AcousticModel(args.acoustic_model_path) - dictionary = MultispeakerDictionary( - args.dictionary_path, - data_directory, - dictionary_config, - 
word_set=corpus.word_set, - logger=logger, - ) - acoustic_model.validate(dictionary) - - begin = time.time() - a = PretrainedAligner( - corpus, - dictionary, - acoustic_model, - align_config, - temp_directory=data_directory, - debug=getattr(args, "debug", False), - logger=logger, - ) - logger.debug(f"Setup pretrained aligner in {time.time() - begin} seconds") - a.verbose = args.verbose - - begin = time.time() - a.align() - logger.debug(f"Performed alignment in {time.time() - begin} seconds") - a.generate_pronunciations(args.output_directory) - logger.info(f"Done! Everything took {time.time() - all_begin} seconds") + try: + aligner.align() + aligner.export_lexicons(args.output_directory) except Exception: - conf["dirty"] = True + aligner.dirty = True raise finally: - handlers = logger.handlers[:] - for handler in handlers: - handler.close() - logger.removeHandler(handler) - conf.save(conf_path) + aligner.cleanup() def validate_args(args: Namespace) -> None: @@ -214,7 +83,7 @@ def run_train_dictionary(args: Namespace, unknown: Optional[list] = None) -> Non ---------- args: :class:`~argparse.Namespace` Parsed command line arguments - unknown: List[str] + unknown: list[str] Parsed command line arguments to be passed to the configuration objects """ validate_args(args) diff --git a/montreal_forced_aligner/command_line/train_g2p.py b/montreal_forced_aligner/command_line/train_g2p.py index a3cef78e..92cb9cbb 100644 --- a/montreal_forced_aligner/command_line/train_g2p.py +++ b/montreal_forced_aligner/command_line/train_g2p.py @@ -1,18 +1,10 @@ """Command line functions for training G2P models""" from __future__ import annotations -import os -import shutil from typing import TYPE_CHECKING, Optional from montreal_forced_aligner.command_line.utils import validate_model_arg -from montreal_forced_aligner.config import TEMP_DIR -from montreal_forced_aligner.config.train_g2p_config import ( - load_basic_train_g2p_config, - train_g2p_yaml_to_config, -) -from 
montreal_forced_aligner.dictionary import PronunciationDictionary -from montreal_forced_aligner.g2p.trainer import PyniniTrainer as Trainer +from montreal_forced_aligner.g2p.trainer import PyniniTrainer if TYPE_CHECKING: from argparse import Namespace @@ -29,36 +21,26 @@ def train_g2p(args: Namespace, unknown_args: Optional[list] = None) -> None: ---------- args: :class:`~argparse.Namespace` Command line arguments - unknown_args: List[str] + unknown_args: list[str] Optional arguments that will be passed to configuration objects """ - if not args.temp_directory: - temp_dir = TEMP_DIR - else: - temp_dir = os.path.expanduser(args.temp_directory) - if args.clean: - shutil.rmtree(os.path.join(temp_dir, "G2P"), ignore_errors=True) - shutil.rmtree(os.path.join(temp_dir, "models", "G2P"), ignore_errors=True) - if args.config_path: - train_config, dictionary_config = train_g2p_yaml_to_config(args.config_path) - else: - train_config, dictionary_config = load_basic_train_g2p_config() - train_config.use_mp = not args.disable_mp - if unknown_args: - train_config.update_from_unknown_args(unknown_args) - dictionary = PronunciationDictionary(args.dictionary_path, "", dictionary_config) - t = Trainer( - dictionary, - args.output_model_path, - temp_directory=temp_dir, - train_config=train_config, - num_jobs=args.num_jobs, - verbose=args.verbose, + + trainer = PyniniTrainer( + dictionary_path=args.dictionary_path, + temporary_directory=args.temporary_directory, + **PyniniTrainer.parse_parameters(args.config_path, args, unknown_args) ) - if args.validate: - t.validate() - else: - t.train() + + try: + trainer.setup() + trainer.train() + trainer.export_model(args.output_model_path) + + except Exception: + trainer.dirty = True + raise + finally: + trainer.cleanup() def validate_args(args: Namespace) -> None: @@ -86,7 +68,7 @@ def run_train_g2p(args: Namespace, unknown: Optional[list] = None) -> None: ---------- args: :class:`~argparse.Namespace` Parsed command line arguments - unknown: 
List[str] + unknown: list[str] Parsed command line arguments to be passed to the configuration objects """ validate_args(args) diff --git a/montreal_forced_aligner/command_line/train_ivector_extractor.py b/montreal_forced_aligner/command_line/train_ivector_extractor.py index 1335b762..252bd443 100644 --- a/montreal_forced_aligner/command_line/train_ivector_extractor.py +++ b/montreal_forced_aligner/command_line/train_ivector_extractor.py @@ -2,23 +2,10 @@ from __future__ import annotations import os -import shutil -import time from typing import TYPE_CHECKING, Optional -from montreal_forced_aligner.aligner import PretrainedAligner -from montreal_forced_aligner.command_line.utils import validate_model_arg -from montreal_forced_aligner.config import ( - TEMP_DIR, - load_basic_train_ivector, - load_command_configuration, - train_yaml_to_config, -) -from montreal_forced_aligner.corpus import Corpus -from montreal_forced_aligner.dictionary import MultispeakerDictionary from montreal_forced_aligner.exceptions import ArgumentError -from montreal_forced_aligner.models import AcousticModel -from montreal_forced_aligner.utils import log_config, setup_logger +from montreal_forced_aligner.ivector.trainer import TrainableIvectorExtractor if TYPE_CHECKING: from argparse import Namespace @@ -34,154 +21,26 @@ def train_ivector(args: Namespace, unknown_args: Optional[list] = None) -> None: ---------- args: :class:`~argparse.Namespace` Command line arguments - unknown_args: List[str] + unknown_args: list[str] Optional arguments that will be passed to configuration objects """ - from montreal_forced_aligner.utils import get_mfa_version - - command = "train_ivector" - all_begin = time.time() - if not args.temp_directory: - temp_dir = TEMP_DIR - else: - temp_dir = os.path.expanduser(args.temp_directory) - corpus_name = os.path.basename(args.corpus_directory) - if corpus_name == "": - args.corpus_directory = os.path.dirname(args.corpus_directory) - corpus_name = 
os.path.basename(args.corpus_directory) - data_directory = os.path.join(temp_dir, corpus_name) - if args.config_path: - train_config, align_config, dictionary_config = train_yaml_to_config(args.config_path) - else: - train_config, align_config, dictionary_config = load_basic_train_ivector() - if unknown_args: - train_config.update_from_unknown_args(unknown_args) - align_config.update_from_unknown_args(unknown_args) - train_config.use_mp = not args.disable_mp - align_config.use_mp = not args.disable_mp - conf_path = os.path.join(data_directory, "config.yml") - if getattr(args, "clean", False) and os.path.exists(data_directory): - print("Cleaning old directory!") - shutil.rmtree(data_directory, ignore_errors=True) - if getattr(args, "verbose", False): - log_level = "debug" - else: - log_level = "info" - logger = setup_logger(command, data_directory, console_level=log_level) - logger.debug("TRAIN CONFIG:") - log_config(logger, train_config) - logger.debug("ALIGN CONFIG:") - log_config(logger, align_config) - - conf = load_command_configuration( - conf_path, - { - "dirty": False, - "begin": all_begin, - "version": get_mfa_version(), - "type": command, - "corpus_directory": args.corpus_directory, - "dictionary_path": args.dictionary_path, - "acoustic_model_path": args.acoustic_model_path, - }, + + trainer = TrainableIvectorExtractor( + corpus_directory=args.corpus_directory, + temporary_directory=args.temporary_directory, + **TrainableIvectorExtractor.parse_parameters(args.config_path, args, unknown_args), ) - if ( - conf["dirty"] - or conf["type"] != command - or conf["corpus_directory"] != args.corpus_directory - or conf["version"] != get_mfa_version() - or conf["dictionary_path"] != args.dictionary_path - or conf["acoustic_model_path"] != args.acoustic_model_path - ): - logger.warning( - "WARNING: Using old temp directory, this might not be ideal for you, use the --clean flag to ensure no " - "weird behavior for previous versions of the temporary directory." 
- ) - if conf["dirty"]: - logger.debug("Previous run ended in an error (maybe ctrl-c?)") - if conf["type"] != command: - logger.debug( - f"Previous run was a different subcommand than {command} (was {conf['type']})" - ) - if conf["corpus_directory"] != args.corpus_directory: - logger.debug( - "Previous run used source directory " - f"path {conf['corpus_directory']} (new run: {args.corpus_directory})" - ) - if conf["version"] != get_mfa_version(): - logger.debug( - f"Previous run was on {conf['version']} version (new run: {get_mfa_version()})" - ) - if conf["dictionary_path"] != args.dictionary_path: - logger.debug( - f"Previous run used dictionary path {conf['dictionary_path']} " - f"(new run: {args.dictionary_path})" - ) - if conf["acoustic_model_path"] != args.acoustic_model_path: - logger.debug( - f"Previous run used acoustic model path {conf['acoustic_model_path']} " - f"(new run: {args.acoustic_model_path})" - ) - os.makedirs(data_directory, exist_ok=True) - model_directory = os.path.join(data_directory, "acoustic_models") try: - begin = time.time() - corpus = Corpus( - args.corpus_directory, - data_directory, - dictionary_config, - speaker_characters=args.speaker_characters, - num_jobs=args.num_jobs, - sample_rate=align_config.feature_config.sample_frequency, - debug=getattr(args, "debug", False), - logger=logger, - use_mp=align_config.use_mp, - ) - dictionary = MultispeakerDictionary( - args.dictionary_path, - data_directory, - dictionary_config, - word_set=corpus.word_set, - logger=logger, - ) - acoustic_model = AcousticModel(args.acoustic_model_path, root_directory=model_directory) - acoustic_model.log_details(logger) - acoustic_model.validate(dictionary) - a = PretrainedAligner( - corpus, - dictionary, - acoustic_model, - align_config, - temp_directory=data_directory, - logger=logger, - ) - logger.debug(f"Setup pretrained aligner in {time.time() - begin} seconds") - a.verbose = args.verbose - begin = time.time() - a.align() - logger.debug(f"Performed 
alignment in {time.time() - begin} seconds") - for identifier, trainer in train_config.items(): - trainer.logger = logger - if identifier != "ivector": - continue - begin = time.time() - trainer.init_training(identifier, data_directory, corpus, dictionary, a) - trainer.train() - logger.debug(f"Training took {time.time() - begin} seconds") - trainer.save(args.output_model_path, root_directory=model_directory) - - logger.info("All done!") - logger.debug(f"Done! Everything took {time.time() - all_begin} seconds") - except Exception as e: - conf["dirty"] = True - raise e + + trainer.train() + trainer.export_model(args.output_model_path) + + except Exception: + trainer.dirty = True + raise finally: - handlers = logger.handlers[:] - for handler in handlers: - handler.close() - logger.removeHandler(handler) - conf.save(conf_path) + trainer.cleanup() def validate_args(args: Namespace) -> None: @@ -215,9 +74,6 @@ def validate_args(args: Namespace) -> None: ) ) - args.dictionary_path = validate_model_arg(args.dictionary_path, "dictionary") - args.acoustic_model_path = validate_model_arg(args.acoustic_model_path, "acoustic") - def run_train_ivector_extractor(args: Namespace, unknown: Optional[list] = None) -> None: """ @@ -227,7 +83,7 @@ def run_train_ivector_extractor(args: Namespace, unknown: Optional[list] = None) ---------- args: :class:`~argparse.Namespace` Parsed command line arguments - unknown: List[str] + unknown: list[str] Parsed command line arguments to be passed to the configuration objects """ validate_args(args) diff --git a/montreal_forced_aligner/command_line/train_lm.py b/montreal_forced_aligner/command_line/train_lm.py index 48798b39..5393fea6 100644 --- a/montreal_forced_aligner/command_line/train_lm.py +++ b/montreal_forced_aligner/command_line/train_lm.py @@ -2,17 +2,15 @@ from __future__ import annotations import os -import shutil -import time from typing import TYPE_CHECKING, Optional from montreal_forced_aligner.command_line.utils import 
validate_model_arg -from montreal_forced_aligner.config import TEMP_DIR, load_basic_train_lm, train_lm_yaml_to_config -from montreal_forced_aligner.corpus import Corpus -from montreal_forced_aligner.dictionary import PronunciationDictionary from montreal_forced_aligner.exceptions import ArgumentError -from montreal_forced_aligner.lm.trainer import LmTrainer -from montreal_forced_aligner.utils import setup_logger +from montreal_forced_aligner.language_modeling.trainer import ( + LmArpaTrainer, + LmCorpusTrainer, + LmDictionaryCorpusTrainer, +) if TYPE_CHECKING: from argparse import Namespace @@ -28,79 +26,41 @@ def train_lm(args: Namespace, unknown_args: Optional[list] = None) -> None: ---------- args: :class:`~argparse.Namespace` Command line arguments - unknown_args: List[str] + unknown_args: list[str] Optional arguments that will be passed to configuration objects """ - command = "train_lm" - all_begin = time.time() - if not args.temp_directory: - temp_dir = TEMP_DIR - else: - temp_dir = os.path.expanduser(args.temp_directory) - if args.config_path: - train_config, dictionary_config = train_lm_yaml_to_config(args.config_path) - else: - train_config, dictionary_config = load_basic_train_lm() - train_config.use_mp = not args.disable_mp - if unknown_args: - train_config.update_from_unknown_args(unknown_args) - corpus_name = os.path.basename(args.source_path) - if corpus_name == "": - args.source_path = os.path.dirname(args.source_path) - corpus_name = os.path.basename(args.source_path) - source = args.source_path - dictionary = None - if args.source_path.lower().endswith(".arpa"): - corpus_name = os.path.splitext(corpus_name)[0] - data_directory = os.path.join(temp_dir, corpus_name) - else: - data_directory = os.path.join(temp_dir, corpus_name) - if getattr(args, "clean", False) and os.path.exists(data_directory): - print("Cleaning old directory!") - shutil.rmtree(data_directory, ignore_errors=True) - if getattr(args, "verbose", False): - log_level = "debug" - else: 
- log_level = "info" - logger = setup_logger(command, data_directory, console_level=log_level) if not args.source_path.lower().endswith(".arpa"): - source = Corpus( - args.source_path, - data_directory, - num_jobs=args.num_jobs, - use_mp=train_config.use_mp, - parse_text_only_files=True, - debug=args.debug, - ) - if args.dictionary_path: - dictionary = PronunciationDictionary( - args.dictionary_path, data_directory, dictionary_config, word_set=source.word_set + if not args.dictionary_path: + trainer = LmCorpusTrainer( + corpus_directory=args.source_path, + temporary_directory=args.temporary_directory, + **LmCorpusTrainer.parse_parameters(args.config_path, args, unknown_args), ) - dictionary.generate_mappings() else: - dictionary = None - trainer = LmTrainer( - source, - train_config, - args.output_model_path, - dictionary=dictionary, - temp_directory=data_directory, - supplemental_model_path=args.model_path, - supplemental_model_weight=args.model_weight, - debug=args.debug, - logger=logger, - ) - begin = time.time() - trainer.train() - logger.debug(f"Training took {time.time() - begin} seconds") - - logger.info("All done!") - logger.debug(f"Done! 
Everything took {time.time() - all_begin} seconds") - handlers = logger.handlers[:] - for handler in handlers: - handler.close() - logger.removeHandler(handler) + trainer = LmDictionaryCorpusTrainer( + corpus_directory=args.source_path, + dictionary_path=args.dictionary_path, + temporary_directory=args.temporary_directory, + **LmDictionaryCorpusTrainer.parse_parameters(args.config_path, args, unknown_args), + ) + else: + trainer = LmArpaTrainer( + arpa_path=args.source_path, + temporary_directory=args.temporary_directory, + **LmArpaTrainer.parse_parameters(args.config_path, args, unknown_args), + ) + + try: + trainer.setup() + trainer.train() + trainer.export_model(args.output_model_path) + + except Exception: + trainer.dirty = True + raise + finally: + trainer.cleanup() def validate_args(args: Namespace) -> None: @@ -146,7 +106,7 @@ def run_train_lm(args: Namespace, unknown: Optional[list] = None) -> None: ---------- args: :class:`~argparse.Namespace` Parsed command line arguments - unknown: List[str] + unknown: list[str] Parsed command line arguments to be passed to the configuration objects """ validate_args(args) diff --git a/montreal_forced_aligner/command_line/transcribe.py b/montreal_forced_aligner/command_line/transcribe.py index b94ddf5c..7cd7a4b1 100644 --- a/montreal_forced_aligner/command_line/transcribe.py +++ b/montreal_forced_aligner/command_line/transcribe.py @@ -2,23 +2,11 @@ from __future__ import annotations import os -import shutil -import time from typing import TYPE_CHECKING, Optional from montreal_forced_aligner.command_line.utils import validate_model_arg -from montreal_forced_aligner.config import ( - TEMP_DIR, - load_basic_transcribe, - load_command_configuration, - transcribe_yaml_to_config, -) -from montreal_forced_aligner.corpus import Corpus -from montreal_forced_aligner.dictionary import MultispeakerDictionary from montreal_forced_aligner.exceptions import ArgumentError -from montreal_forced_aligner.models import AcousticModel, 
LanguageModel -from montreal_forced_aligner.transcriber import Transcriber -from montreal_forced_aligner.utils import log_config, setup_logger +from montreal_forced_aligner.transcription import Transcriber if TYPE_CHECKING: from argparse import Namespace @@ -35,172 +23,27 @@ def transcribe_corpus(args: Namespace, unknown_args: Optional[list] = None) -> N ---------- args: :class:`~argparse.Namespace` Parsed command line arguments - unknown_args: List[str] + unknown_args: list[str] Optional arguments that will be passed to configuration objects """ - from montreal_forced_aligner.utils import get_mfa_version - - command = "transcribe" - all_begin = time.time() - if not args.temp_directory: - temp_dir = TEMP_DIR - else: - temp_dir = os.path.expanduser(args.temp_directory) - corpus_name = os.path.basename(args.corpus_directory) - if corpus_name == "": - args.corpus_directory = os.path.dirname(args.corpus_directory) - corpus_name = os.path.basename(args.corpus_directory) - if args.config_path: - transcribe_config, dictionary_config = transcribe_yaml_to_config(args.config_path) - else: - transcribe_config, dictionary_config = load_basic_transcribe() - transcribe_config.use_mp = not args.disable_mp - transcribe_config.overwrite = args.overwrite - if unknown_args: - transcribe_config.update_from_unknown_args(unknown_args) - data_directory = os.path.join(temp_dir, corpus_name) - if getattr(args, "clean", False) and os.path.exists(data_directory): - print("Cleaning old directory!") - shutil.rmtree(data_directory, ignore_errors=True) - if getattr(args, "verbose", False): - log_level = "debug" - else: - log_level = "info" - logger = setup_logger(command, data_directory, console_level=log_level) - logger.debug("TRANSCRIBE CONFIG:") - log_config(logger, transcribe_config) - os.makedirs(data_directory, exist_ok=True) - model_directory = os.path.join(data_directory, "acoustic_models") - os.makedirs(args.output_directory, exist_ok=True) - os.makedirs(model_directory, exist_ok=True) 
- conf_path = os.path.join(data_directory, "config.yml") - conf = load_command_configuration( - conf_path, - { - "dirty": False, - "begin": time.time(), - "version": get_mfa_version(), - "type": "transcribe", - "corpus_directory": args.corpus_directory, - "dictionary_path": args.dictionary_path, - "acoustic_model_path": args.acoustic_model_path, - "language_model_path": args.language_model_path, - }, + transcriber = Transcriber( + acoustic_model_path=args.acoustic_model_path, + language_model_path=args.language_model_path, + corpus_directory=args.corpus_directory, + audio_directory=args.audio_directory, + dictionary_path=args.dictionary_path, + temporary_directory=args.temporary_directory, + **Transcriber.parse_parameters(args.config_path, args, unknown_args), ) - - if ( - conf["dirty"] - or conf["type"] != command - or conf["corpus_directory"] != args.corpus_directory - or conf["version"] != get_mfa_version() - or conf["dictionary_path"] != args.dictionary_path - or conf["language_model_path"] != args.language_model_path - or conf["acoustic_model_path"] != args.acoustic_model_path - ): - logger.warning( - "WARNING: Using old temp directory, this might not be ideal for you, use the --clean flag to ensure no " - "weird behavior for previous versions of the temporary directory." 
- ) - if conf["dirty"]: - logger.debug("Previous run ended in an error (maybe ctrl-c?)") - if conf["type"] != command: - logger.debug( - f"Previous run was a different subcommand than {command} (was {conf['type']})" - ) - if conf["corpus_directory"] != args.corpus_directory: - logger.debug( - "Previous run used source directory " - f"path {conf['corpus_directory']} (new run: {args.corpus_directory})" - ) - if conf["version"] != get_mfa_version(): - logger.debug( - f"Previous run was on {conf['version']} version (new run: {get_mfa_version()})" - ) - if conf["dictionary_path"] != args.dictionary_path: - logger.debug( - f"Previous run used dictionary path {conf['dictionary_path']} " - f"(new run: {args.dictionary_path})" - ) - if conf["acoustic_model_path"] != args.acoustic_model_path: - logger.debug( - f"Previous run used acoustic model path {conf['acoustic_model_path']} " - f"(new run: {args.acoustic_model_path})" - ) - if conf["language_model_path"] != args.language_model_path: - logger.debug( - f"Previous run used language model path {conf['language_model_path']} " - f"(new run: {args.language_model_path})" - ) - audio_dir = None - if args.audio_directory: - audio_dir = args.audio_directory try: - corpus = Corpus( - args.corpus_directory, - data_directory, - dictionary_config, - speaker_characters=args.speaker_characters, - sample_rate=transcribe_config.feature_config.sample_frequency, - num_jobs=args.num_jobs, - use_mp=transcribe_config.use_mp, - audio_directory=audio_dir, - ignore_speakers=transcribe_config.ignore_speakers, - ) - acoustic_model = AcousticModel(args.acoustic_model_path, root_directory=model_directory) - dictionary_config.update(acoustic_model.meta) - acoustic_model.log_details(logger) - if args.language_model_path.endswith(".arpa"): - alternative_name = os.path.splitext(args.language_model_path)[0] + ".zip" - logger.warning( - f"Using a plain .arpa model requires generating pruned versions of it to decode in a reasonable " - f"amount of time. 
If you'd like to generate a reusable language model, consider running " - f"`mfa train_lm {args.language_model_path} {alternative_name}`." - ) - language_model = LanguageModel(args.language_model_path, root_directory=data_directory) - dictionary = MultispeakerDictionary( - args.dictionary_path, - data_directory, - dictionary_config, - logger=logger, - ) - - acoustic_model.validate(dictionary) - begin = time.time() - t = Transcriber( - corpus, - dictionary, - acoustic_model, - language_model, - transcribe_config, - temp_directory=data_directory, - debug=getattr(args, "debug", False), - evaluation_mode=args.evaluate, - logger=logger, - ) - logger.debug(f"Setup transcriber in {time.time() - begin} seconds") - - begin = time.time() - t.transcribe() - logger.debug(f"Performed transcribing in {time.time() - begin} seconds") - if args.evaluate: - t.evaluate() - best_config_path = os.path.join(args.output_directory, "best_transcribe_config.yaml") - t.transcribe_config.save(best_config_path) - t.export_transcriptions(args.output_directory) - else: - begin = time.time() - t.export_transcriptions(args.output_directory) - logger.debug(f"Exported transcriptions in {time.time() - begin} seconds") - logger.info(f"Done! 
Everything took {time.time() - all_begin} seconds") + transcriber.setup() + transcriber.transcribe() + transcriber.export_files(args.output_directory) except Exception: - conf["dirty"] = True + transcriber.dirty = True raise finally: - handlers = logger.handlers[:] - for handler in handlers: - handler.close() - logger.removeHandler(handler) - conf.save(conf_path) + transcriber.cleanup() def validate_args(args: Namespace) -> None: @@ -247,7 +90,7 @@ def run_transcribe_corpus(args: Namespace, unknown: Optional[list] = None) -> No ---------- args: :class:`~argparse.Namespace` Parsed command line arguments - unknown: List[str] + unknown: list[str] Parsed command line arguments to be passed to the configuration objects """ validate_args(args) diff --git a/montreal_forced_aligner/command_line/utils.py b/montreal_forced_aligner/command_line/utils.py index b7a3f387..53591cf6 100644 --- a/montreal_forced_aligner/command_line/utils.py +++ b/montreal_forced_aligner/command_line/utils.py @@ -13,7 +13,6 @@ PretrainedModelNotFoundError, ) from ..models import MODEL_TYPES -from ..utils import get_available_models, get_pretrained_path __all__ = ["validate_model_arg"] @@ -49,10 +48,11 @@ def validate_model_arg(name: str, model_type: str) -> str: """ if model_type not in MODEL_TYPES: raise ModelTypeNotSupportedError(model_type, MODEL_TYPES) - available_models = get_available_models(model_type) + model_class = MODEL_TYPES[model_type] + available_models = model_class.get_available_models() model_class = MODEL_TYPES[model_type] if name in available_models: - name = get_pretrained_path(model_type, name) + name = model_class.get_pretrained_path(name) elif model_class.valid_extension(name): if not os.path.exists(name): raise FileArgumentNotFoundError(name) diff --git a/montreal_forced_aligner/command_line/validate.py b/montreal_forced_aligner/command_line/validate.py index e161e95c..91522bc9 100644 --- a/montreal_forced_aligner/command_line/validate.py +++ 
b/montreal_forced_aligner/command_line/validate.py @@ -2,19 +2,11 @@ from __future__ import annotations import os -import shutil -import time -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional from montreal_forced_aligner.command_line.utils import validate_model_arg -from montreal_forced_aligner.config import TEMP_DIR -from montreal_forced_aligner.config.dictionary_config import DictionaryConfig -from montreal_forced_aligner.corpus import Corpus -from montreal_forced_aligner.dictionary import MultispeakerDictionary from montreal_forced_aligner.exceptions import ArgumentError -from montreal_forced_aligner.models import AcousticModel -from montreal_forced_aligner.utils import setup_logger -from montreal_forced_aligner.validator import CorpusValidator +from montreal_forced_aligner.validator import PretrainedValidator, TrainingValidator if TYPE_CHECKING: from argparse import Namespace @@ -23,7 +15,7 @@ __all__ = ["validate_corpus", "validate_args", "run_validate_corpus"] -def validate_corpus(args: Namespace, unknown_args: Optional[List[str]] = None) -> None: +def validate_corpus(args: Namespace, unknown_args: Optional[list[str]] = None) -> None: """ Run the validation command @@ -31,72 +23,33 @@ def validate_corpus(args: Namespace, unknown_args: Optional[List[str]] = None) - ---------- args: :class:`~argparse.Namespace` Command line arguments - unknown_args: List[str] + unknown_args: list[str] Optional arguments that will be passed to configuration objects """ - command = "validate" - all_begin = time.time() - if not args.temp_directory: - temp_dir = TEMP_DIR - else: - temp_dir = os.path.expanduser(args.temp_directory) - corpus_name = os.path.basename(args.corpus_directory) - if corpus_name == "": - args.corpus_directory = os.path.dirname(args.corpus_directory) - corpus_name = os.path.basename(args.corpus_directory) - data_directory = os.path.join(temp_dir, corpus_name) - shutil.rmtree(data_directory, ignore_errors=True) - - 
os.makedirs(data_directory, exist_ok=True) - model_directory = os.path.join(data_directory, "acoustic_models") - os.makedirs(model_directory, exist_ok=True) - if getattr(args, "verbose", False): - log_level = "debug" - else: - log_level = "info" - logger = setup_logger(command, data_directory, console_level=log_level) - dictionary_config = DictionaryConfig() - acoustic_model = None if args.acoustic_model_path: - acoustic_model = AcousticModel(args.acoustic_model_path, root_directory=model_directory) - acoustic_model.log_details(logger) - dictionary_config.update(acoustic_model.meta) - dictionary = MultispeakerDictionary( - args.dictionary_path, - data_directory, - dictionary_config, - logger=logger, - ) - if acoustic_model: - acoustic_model.validate(dictionary) - - corpus = Corpus( - args.corpus_directory, - data_directory, - dictionary_config, - speaker_characters=args.speaker_characters, - num_jobs=getattr(args, "num_jobs", 3), - logger=logger, - use_mp=not args.disable_mp, - ) - a = CorpusValidator( - corpus, - dictionary, - temp_directory=data_directory, - ignore_acoustics=getattr(args, "ignore_acoustics", False), - test_transcriptions=getattr(args, "test_transcriptions", False), - use_mp=not args.disable_mp, - logger=logger, - ) - begin = time.time() - a.validate() - logger.debug(f"Validation took {time.time() - begin} seconds") - logger.info("All done!") - logger.debug(f"Done! 
Everything took {time.time() - all_begin} seconds") - handlers = logger.handlers[:] - for handler in handlers: - handler.close() - logger.removeHandler(handler) + validator = PretrainedValidator( + acoustic_model_path=args.acoustic_model_path, + corpus_directory=args.corpus_directory, + audio_directory=args.audio_directory, + dictionary_path=args.dictionary_path, + temporary_directory=args.temporary_directory, + **PretrainedValidator.parse_parameters(args.config_path, args, unknown_args), + ) + else: + validator = TrainingValidator( + corpus_directory=args.corpus_directory, + audio_directory=args.audio_directory, + dictionary_path=args.dictionary_path, + temporary_directory=args.temporary_directory, + **TrainingValidator.parse_parameters(args.config_path, args, unknown_args), + ) + try: + validator.validate() + except Exception: + validator.dirty = True + raise + finally: + validator.cleanup() def validate_args(args: Namespace) -> None: @@ -133,7 +86,7 @@ def validate_args(args: Namespace) -> None: args.acoustic_model_path = validate_model_arg(args.acoustic_model_path, "acoustic") -def run_validate_corpus(args: Namespace, unknown: Optional[List[str]] = None) -> None: +def run_validate_corpus(args: Namespace, unknown: Optional[list[str]] = None) -> None: """ Wrapper function for running corpus validation @@ -141,7 +94,7 @@ def run_validate_corpus(args: Namespace, unknown: Optional[List[str]] = None) -> ---------- args: :class:`~argparse.Namespace` Parsed command line arguments - unknown: List[str] + unknown: list[str] Parsed command line arguments to be passed to the configuration objects """ validate_args(args) diff --git a/montreal_forced_aligner/config/__init__.py b/montreal_forced_aligner/config.py similarity index 62% rename from montreal_forced_aligner/config/__init__.py rename to montreal_forced_aligner/config.py index c5355586..9634945d 100644 --- a/montreal_forced_aligner/config/__init__.py +++ b/montreal_forced_aligner/config.py @@ -1,12 +1,14 @@ """ 
-Configuration classes -===================== - +MFA configuration +================= """ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, List +import re +from typing import TYPE_CHECKING, Any + +from montreal_forced_aligner.exceptions import RootDirectoryError if TYPE_CHECKING: from argparse import Namespace @@ -15,53 +17,7 @@ import yaml -from .align_config import AlignConfig, align_yaml_to_config, load_basic_align # noqa -from .base_config import BaseConfig -from .command_config import CommandConfig, load_command_configuration # noqa -from .dictionary_config import DictionaryConfig # noqa -from .feature_config import FeatureConfig # noqa -from .g2p_config import G2PConfig, g2p_yaml_to_config, load_basic_g2p_config # noqa -from .segmentation_config import ( # noqa - SegmentationConfig, - load_basic_segmentation, - segmentation_yaml_to_config, -) -from .speaker_classification_config import ( # noqa - SpeakerClassificationConfig, - classification_yaml_to_config, - load_basic_classification, -) -from .train_config import ( # noqa - TrainingConfig, - load_basic_train, - load_basic_train_ivector, - load_test_config, - train_yaml_to_config, -) -from .train_g2p_config import ( # noqa - TrainG2PConfig, - load_basic_train_g2p_config, - train_g2p_yaml_to_config, -) -from .train_lm_config import TrainLMConfig, load_basic_train_lm, train_lm_yaml_to_config # noqa -from .transcribe_config import ( # noqa - TranscribeConfig, - load_basic_transcribe, - transcribe_yaml_to_config, -) - __all__ = [ - "TEMP_DIR", - "align_config", - "base_config", - "command_config", - "dictionary_config", - "feature_config", - "segmentation_config", - "speaker_classification_config", - "train_config", - "train_lm_config", - "transcribe_config", "generate_config_path", "generate_command_history_path", "load_command_history", @@ -72,21 +28,28 @@ "BLAS_THREADS", ] -BaseConfig.__module__ = "montreal_forced_aligner.config" -AlignConfig.__module__ = 
"montreal_forced_aligner.config" -CommandConfig.__module__ = "montreal_forced_aligner.config" -FeatureConfig.__module__ = "montreal_forced_aligner.config" -DictionaryConfig.__module__ = "montreal_forced_aligner.config" -SegmentationConfig.__module__ = "montreal_forced_aligner.config" -SpeakerClassificationConfig.__module__ = "montreal_forced_aligner.config" -TrainingConfig.__module__ = "montreal_forced_aligner.config" -TrainLMConfig.__module__ = "montreal_forced_aligner.config" -TrainG2PConfig.__module__ = "montreal_forced_aligner.config" -G2PConfig.__module__ = "montreal_forced_aligner.config" -TranscribeConfig.__module__ = "montreal_forced_aligner.config" +MFA_ROOT_ENVIRONMENT_VARIABLE = "MFA_ROOT_DIR" -TEMP_DIR = os.path.expanduser("~/Documents/MFA") +def get_temporary_directory(): + """ + Get the root temporary directory for MFA + + Returns + ------- + str + Root temporary directory + + Raises + ------ + :class:`~montreal_forced_aligner.exceptions.RootDirectoryError` + """ + TEMP_DIR = os.environ.get(MFA_ROOT_ENVIRONMENT_VARIABLE, os.path.expanduser("~/Documents/MFA")) + try: + os.makedirs(TEMP_DIR, exist_ok=True) + except OSError: + raise RootDirectoryError(TEMP_DIR, MFA_ROOT_ENVIRONMENT_VARIABLE) + return TEMP_DIR def generate_config_path() -> str: @@ -98,7 +61,7 @@ def generate_config_path() -> str: str Full path to configuration yaml """ - return os.path.join(TEMP_DIR, "global_config.yaml") + return os.path.join(get_temporary_directory(), "global_config.yaml") def generate_command_history_path() -> str: @@ -110,16 +73,16 @@ def generate_command_history_path() -> str: str Full path to history file """ - return os.path.join(TEMP_DIR, "command_history.yaml") + return os.path.join(get_temporary_directory(), "command_history.yaml") -def load_command_history() -> List[str]: +def load_command_history() -> list[dict[str, Any]]: """ Load command history for MFA Returns ------- - List + list[dict[str, Any]] List of commands previously run """ path = 
generate_command_history_path() @@ -127,16 +90,18 @@ def load_command_history() -> List[str]: if os.path.exists(path): with open(path, "r", encoding="utf8") as f: history = yaml.safe_load(f) + for h in history: + h["command"] = re.sub(r"^\S+.py ", "mfa ", h["command"]) return history -def update_command_history(command_data: dict) -> None: +def update_command_history(command_data: dict[str, Any]) -> None: """ Update command history with most recent command Parameters ---------- - command_data: dict + command_data: dict[str, Any] Current command metadata """ try: @@ -173,7 +138,7 @@ def update_global_config(args: Namespace) -> None: "num_jobs": 3, "blas_num_threads": 1, "use_mp": True, - "temp_directory": TEMP_DIR, + "temporary_directory": get_temporary_directory(), } if os.path.exists(global_configuration_file): with open(global_configuration_file, "r", encoding="utf8") as f: @@ -213,19 +178,19 @@ def update_global_config(args: Namespace) -> None: default_config["terminal_width"] = args.terminal_width if args.blas_num_threads and args.blas_num_threads > 0: default_config["blas_num_threads"] = args.blas_num_threads - if args.temp_directory: - default_config["temp_directory"] = args.temp_directory + if args.temporary_directory: + default_config["temporary_directory"] = args.temporary_directory with open(global_configuration_file, "w", encoding="utf8") as f: yaml.dump(default_config, f) -def load_global_config() -> Dict[str, Any]: +def load_global_config() -> dict[str, Any]: """ Load the global MFA configuration Returns ------- - Dict + dict[str, Any] Global configuration """ global_configuration_file = generate_config_path() @@ -240,12 +205,14 @@ def load_global_config() -> Dict[str, Any]: "num_jobs": 3, "blas_num_threads": 1, "use_mp": True, - "temp_directory": TEMP_DIR, + "temporary_directory": get_temporary_directory(), } if os.path.exists(global_configuration_file): with open(global_configuration_file, "r", encoding="utf8") as f: data = yaml.safe_load(f) 
default_config.update(data) + if "temp_directory" in default_config: + default_config["temporary_directory"] = default_config["temp_directory"] return default_config diff --git a/montreal_forced_aligner/config/adapt_nosat.yaml b/montreal_forced_aligner/config/adapt_nosat.yaml deleted file mode 100644 index 6d397f26..00000000 --- a/montreal_forced_aligner/config/adapt_nosat.yaml +++ /dev/null @@ -1,17 +0,0 @@ -beam: 10 -retry_beam: 40 - -features: - type: "mfcc" - use_energy: false - frame_shift: 10 - snip_edges: true - -training: - - - triphone: - num_iterations: 3 - num_leaves: 4200 - max_gaussians: 40000 - power: 0.25 - boost_silence: 1.25 diff --git a/montreal_forced_aligner/config/adapt_sat.yaml b/montreal_forced_aligner/config/adapt_sat.yaml deleted file mode 100644 index 1480dd61..00000000 --- a/montreal_forced_aligner/config/adapt_sat.yaml +++ /dev/null @@ -1,18 +0,0 @@ -beam: 10 -retry_beam: 40 - -features: - type: "mfcc" - use_energy: false - frame_shift: 10 - snip_edges: true - -training: - - - sat: - num_iterations: 3 - num_leaves: 4200 - max_gaussians: 40000 - power: 0.2 - silence_weight: 0.0 - fmllr_update_type: "full" diff --git a/montreal_forced_aligner/config/align_config.py b/montreal_forced_aligner/config/align_config.py deleted file mode 100644 index e629e4ff..00000000 --- a/montreal_forced_aligner/config/align_config.py +++ /dev/null @@ -1,162 +0,0 @@ -"""Class definitions for configuring aligning""" -from __future__ import annotations - -import os -from typing import TYPE_CHECKING, Collection, Tuple - -import yaml - -from ..exceptions import ConfigError -from .base_config import BaseConfig -from .dictionary_config import DictionaryConfig -from .feature_config import FeatureConfig - -if TYPE_CHECKING: - from argparse import Namespace - - from ..abc import MetaDict - -__all__ = ["AlignConfig", "align_yaml_to_config", "load_basic_align"] - - -class AlignConfig(BaseConfig): - """ - Configuration object for alignment - - Attributes - ---------- - 
transition_scale : float - Transition scale, defaults to 1.0 - acoustic_scale : float - Acoustic scale, defaults to 0.1 - self_loop_scale : float - Self-loop scale, defaults to 0.1 - disable_sat : bool - Flag for disabling speaker adaptation, defaults to False - feature_config : :class:`~montreal_forced_aligner.config.FeatureConfig` - Configuration object for feature generation - boost_silence : float - Factor to boost silence probabilities, 1.0 is no boost or reduction - beam : int - Size of the beam to use in decoding, defaults to 10 - retry_beam : int - Size of the beam to use in decoding if it fails with the initial beam width, defaults to 40 - data_directory : str - Path to save feature files - fmllr_update_type : str - Type of update for fMLLR, defaults to full - use_mp : bool - Flag for whether to use multiprocessing in feature generation - """ - - def __init__(self, feature_config: FeatureConfig): - self.transition_scale = 1.0 - self.acoustic_scale = 0.1 - self.self_loop_scale = 0.1 - self.disable_sat = False - self.feature_config = feature_config - self.boost_silence = 1.0 - self.beam = 10 - self.retry_beam = 40 - self.data_directory = None # Gets set later - self.fmllr_update_type = "full" - self.use_mp = True - self.use_fmllr_mp = False - self.debug = False - self.overwrite = False - self.cleanup_textgrids = True - self.initial_fmllr = True - self.iteration = None - - @property - def align_options(self) -> MetaDict: - """Options for use in aligning""" - return { - "transition_scale": self.transition_scale, - "acoustic_scale": self.acoustic_scale, - "self_loop_scale": self.self_loop_scale, - "beam": self.beam, - "retry_beam": self.retry_beam, - "boost_silence": self.boost_silence, - "debug": self.debug, - } - - @property - def fmllr_options(self) -> MetaDict: - """Options for use in calculating fMLLR transforms""" - return { - "fmllr_update_type": self.fmllr_update_type, - } - - def update(self, data: dict) -> None: - """Update configuration""" - for k, v 
in data.items(): - if k == "use_mp": - self.feature_config.use_mp = v - elif not hasattr(self, k): - continue - setattr(self, k, v) - - def update_from_args(self, args: Namespace): - """Update from command line arguments""" - super(AlignConfig, self).update_from_args(args) - self.feature_config.update_from_args(args) - - def update_from_unknown_args(self, args: Collection[str]): - """Update from unknown command line arguments""" - super(AlignConfig, self).update_from_unknown_args(args) - self.feature_config.update_from_unknown_args(args) - if self.retry_beam <= self.beam: - self.retry_beam = self.beam * 4 - - -def align_yaml_to_config(path: str) -> Tuple[AlignConfig, DictionaryConfig]: - """ - Helper function to load alignment configurations - - Parameters - ---------- - path: str - Path to yaml file - - Returns - ------- - :class:`~montreal_forced_aligner.config.AlignConfig` - Alignment configuration - :class:`~montreal_forced_aligner.config.dictionary_config.DictionaryConfig` - Dictionary configuration - """ - dictionary_config = DictionaryConfig() - with open(path, "r", encoding="utf8") as f: - data = yaml.load(f, Loader=yaml.SafeLoader) - global_params = {} - feature_config = FeatureConfig() - for k, v in data.items(): - if k == "features": - feature_config.update(v) - else: - global_params[k] = v - align_config = AlignConfig(feature_config) - align_config.update(global_params) - dictionary_config.update(global_params) - if align_config.beam >= align_config.retry_beam: - raise ConfigError("Retry beam must be greater than beam.") - return align_config, dictionary_config - - -def load_basic_align() -> Tuple[AlignConfig, DictionaryConfig]: - """ - Helper function to load the default parameters - - Returns - ------- - :class:`~montreal_forced_aligner.config.AlignConfig` - Default alignment configuration - :class:`~montreal_forced_aligner.config.DictionaryConfig` - Dictionary configuration - """ - base_dir = os.path.dirname(os.path.abspath(__file__)) - align_config, 
dictionary_config = align_yaml_to_config( - os.path.join(base_dir, "basic_align.yaml") - ) - return align_config, dictionary_config diff --git a/montreal_forced_aligner/config/base_config.py b/montreal_forced_aligner/config/base_config.py deleted file mode 100644 index 4e132046..00000000 --- a/montreal_forced_aligner/config/base_config.py +++ /dev/null @@ -1,106 +0,0 @@ -"""Class definitions for base configuration""" -from __future__ import annotations - -from typing import TYPE_CHECKING, Collection - -import yaml - -if TYPE_CHECKING: - from argparse import Namespace - - -PARSING_KEYS = [ - "punctuation", - "clitic_markers", - "compound_markers", - "multilingual_ipa", - "strip_diacritics", - "digraphs", -] - -__all__ = ["BaseConfig"] - - -class BaseConfig: - """ - Base configuration class - """ - - def update(self, data: dict) -> None: - """Update configuration parameters""" - for k, v in data.items(): - if not hasattr(self, k): - continue - setattr(self, k, v) - - def update_from_args(self, args: Namespace) -> None: - """Update from command line arguments""" - if args is not None: - try: - self.use_mp = not args.disable_mp - except AttributeError: - pass - try: - self.debug = args.debug - except AttributeError: - pass - try: - self.overwrite = args.overwrite - except AttributeError: - pass - try: - self.cleanup_textgrids = not args.disable_textgrid_cleanup - except AttributeError: - pass - - def params(self) -> dict: - """Configuration parameters""" - return {} - - def update_from_unknown_args(self, args: Collection[str]) -> None: - """Update from unknown command line arguments""" - for i, a in enumerate(args): - if not a.startswith("--"): - continue - name = a.replace("--", "") - try: - original_value = getattr(self, name) - except AttributeError: - continue - if not isinstance(original_value, (bool, int, float, str)): - continue - try: - if isinstance(original_value, bool): - if args[i + 1].lower() == "true": - val = True - elif args[i + 1].lower() == "false": - 
val = False - elif not original_value: - val = True - else: - continue - else: - val = type(original_value)(args[i + 1]) - except (ValueError): - continue - except (IndexError): - if isinstance(original_value, bool): - if not original_value: - val = True - else: - continue - else: - continue - setattr(self, name, val) - - def save(self, path: str) -> None: - """ - Dump configuration to path - - Parameters - ---------- - path: str - Path to export to - """ - with open(path, "w", encoding="utf8") as f: - yaml.dump(self.params(), f) diff --git a/montreal_forced_aligner/config/basic_align.yaml b/montreal_forced_aligner/config/basic_align.yaml deleted file mode 100644 index 0106c404..00000000 --- a/montreal_forced_aligner/config/basic_align.yaml +++ /dev/null @@ -1,9 +0,0 @@ -beam: 100 -retry_beam: 400 -disable_sat: false - -features: - type: "mfcc" - use_energy: false - frame_shift: 10 - snip_edges: true diff --git a/montreal_forced_aligner/config/basic_classification.yaml b/montreal_forced_aligner/config/basic_classification.yaml deleted file mode 100644 index e69de29b..00000000 diff --git a/montreal_forced_aligner/config/basic_segmentation.yaml b/montreal_forced_aligner/config/basic_segmentation.yaml deleted file mode 100644 index 4f30e6b8..00000000 --- a/montreal_forced_aligner/config/basic_segmentation.yaml +++ /dev/null @@ -1,5 +0,0 @@ -energy_threshold: 5.5 -energy_mean_scale: 0.5 -max_segment_length: 30 -min_pause_duration: 0.05 -snap_boundary_threshold: 0.15 diff --git a/montreal_forced_aligner/config/basic_train.yaml b/montreal_forced_aligner/config/basic_train.yaml deleted file mode 100644 index ff4d1536..00000000 --- a/montreal_forced_aligner/config/basic_train.yaml +++ /dev/null @@ -1,48 +0,0 @@ -beam: 10 -retry_beam: 40 - -features: - type: "mfcc" - use_energy: false - frame_shift: 10 - snip_edges: true - -training: - - monophone: - num_iterations: 40 - max_gaussians: 1000 - subset: 2000 - boost_silence: 1.25 - - - triphone: - num_iterations: 35 - 
num_leaves: 2000 - max_gaussians: 10000 - cluster_threshold: -1 - subset: 5000 - boost_silence: 1.25 - power: 0.25 - - - lda: - num_leaves: 2500 - max_gaussians: 15000 - subset: 10000 - num_iterations: 35 - features: - splice_left_context: 3 - splice_right_context: 3 - - - sat: - num_leaves: 2500 - max_gaussians: 15000 - power: 0.2 - silence_weight: 0.0 - fmllr_update_type: "full" - subset: 10000 - - - sat: - num_leaves: 4200 - max_gaussians: 40000 - power: 0.2 - silence_weight: 0.0 - fmllr_update_type: "full" diff --git a/montreal_forced_aligner/config/basic_train_ivector.yaml b/montreal_forced_aligner/config/basic_train_ivector.yaml deleted file mode 100644 index 01fa57eb..00000000 --- a/montreal_forced_aligner/config/basic_train_ivector.yaml +++ /dev/null @@ -1,12 +0,0 @@ -features: - type: "mfcc" - use_energy: true - frame_shift: 10 - -training: - - ivector: - num_iterations: 10 - gaussian_min_count: 2 - silence_weight: 0.0 - posterior_scale: 0.1 - max_count: 100 diff --git a/montreal_forced_aligner/config/basic_train_lm.yaml b/montreal_forced_aligner/config/basic_train_lm.yaml deleted file mode 100644 index 2987079e..00000000 --- a/montreal_forced_aligner/config/basic_train_lm.yaml +++ /dev/null @@ -1,6 +0,0 @@ -order: 3 -method: kneser_ney -prune: true -count_threshold: 2 -prune_thresh_small: 0.0000003 -prune_thresh_medium: 0.0000001 diff --git a/montreal_forced_aligner/config/basic_transcribe.yaml b/montreal_forced_aligner/config/basic_transcribe.yaml deleted file mode 100644 index 97388371..00000000 --- a/montreal_forced_aligner/config/basic_transcribe.yaml +++ /dev/null @@ -1,9 +0,0 @@ -beam: 13 -max_active: 7000 -lattice_beam: 6 -acoustic_scale: 0.083333 -silence_weight: 0.01 -fmllr: true -first_beam: 10.0 # Beam used in initial, speaker-indep. pass -first_max_active: 2000 # max-active used in initial pass. 
-fmllr_update_type: full diff --git a/montreal_forced_aligner/config/command_config.py b/montreal_forced_aligner/config/command_config.py deleted file mode 100644 index 56b750c5..00000000 --- a/montreal_forced_aligner/config/command_config.py +++ /dev/null @@ -1,63 +0,0 @@ -"""Class definitions for configuring commands""" -from __future__ import annotations - -import os -from typing import Any - -import yaml - -__all__ = ["CommandConfig", "load_command_configuration"] - - -class CommandConfig(object): - """ - Configuration for running commands - - """ - - def __init__(self, data: dict): - self.data = data - - def __getitem__(self, item: str) -> Any: - """Get key""" - if item not in self.data: - return None - return self.data[item] - - def __setitem__(self, key: str, value: Any) -> None: - """Set key""" - self.data[key] = value - - def update(self, new_data: dict) -> None: - """Update configuration""" - self.data.update(new_data) - - def save(self, conf_path: str) -> None: - """Export to path""" - with open(conf_path, "w") as f: - yaml.dump(self.data, f) - - -def load_command_configuration(conf_path: str, default: dict) -> CommandConfig: - """ - Load a previous run of MFA in a temporary directory - - Parameters - ---------- - conf_path: str - Path to saved configuration - default: dict - Extra parameters to set on load - - Returns - ------- - :class:`~montreal_forced_aligner.config.command_config.CommandConfig` - Command configuration - """ - if os.path.exists(conf_path): - with open(conf_path, "r") as f: - conf = yaml.load(f, Loader=yaml.SafeLoader) - config = CommandConfig(conf) - else: - config = CommandConfig(default) - return config diff --git a/montreal_forced_aligner/config/dictionary_config.py b/montreal_forced_aligner/config/dictionary_config.py deleted file mode 100644 index deb03816..00000000 --- a/montreal_forced_aligner/config/dictionary_config.py +++ /dev/null @@ -1,308 +0,0 @@ -"""Class definitions for configuring pronunciation dictionaries""" -from 
__future__ import annotations - -import re -from typing import Collection, Dict, List, Optional, Set, Tuple, Union - -from .base_config import BaseConfig - -DEFAULT_PUNCTUATION = list(r'、。।,@<>"(),.:;¿?¡!\\&%#*~【】,…‥「」『』〝〟″⟨⟩♪・‹›«»~′$+=‘') - -DEFAULT_CLITIC_MARKERS = list("'’") -DEFAULT_COMPOUND_MARKERS = list("-/") -DEFAULT_STRIP_DIACRITICS = ["ː", "ˑ", "̩", "̆", "̑", "̯", "͡", "‿", "͜"] -DEFAULT_DIGRAPHS = ["[dt][szʒʃʐʑʂɕç]", "[aoɔe][ʊɪ]"] -DEFAULT_BRACKETS = [("[", "]"), ("{", "}"), ("<", ">"), ("(", ")")] - -__all__ = ["DictionaryConfig"] - - -class DictionaryConfig(BaseConfig): - """ - Class for storing configuration information about pronunciation dictionaries - Path to a directory to store files for Kaldi - oov_code : str, optional - What to label words not in the dictionary, defaults to ``''`` - position_dependent_phones : bool, optional - Specifies whether phones should be represented as dependent on their - position in the word (beginning, middle or end), defaults to True - num_sil_states : int, optional - Number of states to use for silence phones, defaults to 5 - num_nonsil_states : int, optional - Number of states to use for non-silence phones, defaults to 3 - shared_silence_phones : bool, optional - Specify whether to share states across all silence phones, defaults - to True - sil_prob : float, optional - Probability of optional silences following words, defaults to 0.5 - word_set : Collection[str], optional - Word set to limit output files - debug: bool, optional - Flag for whether to perform debug steps and prevent intermediate cleanup - logger: :class:`~logging.Logger`, optional - Logger to output information to - punctuation: str, optional - Punctuation to use when parsing text - clitic_markers: str, optional - Clitic markers to use when parsing text - compound_markers: str, optional - Compound markers to use when parsing text - multilingual_ipa: bool, optional - Flag for multilingual IPA mode, defaults to False - strip_diacritics: List[str], 
optional - Diacritics to strip in multilingual IPA mode - digraphs: List[str], optional - Digraphs to split up in multilingual IPA mode - """ - - topo_template = " {cur_state} {cur_state} {cur_state} 0.75 {next_state} 0.25 " - topo_sil_template = " {cur_state} {cur_state} {transitions} " - topo_transition_template = " {} {}" - positions: List[str] = ["_B", "_E", "_I", "_S"] - - def __init__( - self, - oov_word: str = "", - silence_word: str = "!sil", - nonoptional_silence_phone: str = "sil", - optional_silence_phone: str = "sp", - oov_phone: str = "spn", - other_noise_phone: str = "spn", - position_dependent_phones: bool = True, - num_silence_states: int = 5, - num_non_silence_states: int = 3, - shared_silence_phones: bool = True, - silence_probability: float = 0.5, - debug: bool = False, - punctuation: Optional[Union[str, Collection[str]]] = None, - clitic_markers: Optional[Union[str, Collection[str]]] = None, - compound_markers: Optional[Collection[str]] = None, - multilingual_ipa: bool = False, - strip_diacritics: Optional[Collection[str]] = None, - digraphs: Optional[Collection[str]] = None, - brackets: Optional[Collection[Tuple[str, str]]] = None, - ): - self.strip_diacritics = DEFAULT_STRIP_DIACRITICS - self.digraphs = DEFAULT_DIGRAPHS - self.punctuation = DEFAULT_PUNCTUATION - self.clitic_markers = DEFAULT_CLITIC_MARKERS - self.compound_markers = DEFAULT_COMPOUND_MARKERS - self.brackets = DEFAULT_BRACKETS - if strip_diacritics is not None: - self.strip_diacritics = strip_diacritics - if digraphs is not None: - self.digraphs = digraphs - if punctuation is not None: - self.punctuation = punctuation - if clitic_markers is not None: - self.clitic_markers = clitic_markers - if compound_markers is not None: - self.compound_markers = compound_markers - if brackets is not None: - self.brackets = brackets - - self.multilingual_ipa = multilingual_ipa - self.num_silence_states = num_silence_states - self.num_non_silence_states = num_non_silence_states - 
self.shared_silence_phones = shared_silence_phones - self.silence_probability = silence_probability - self.oov_word = oov_word - self.silence_word = silence_word - self.position_dependent_phones = position_dependent_phones - self.optional_silence_phone = optional_silence_phone - self.nonoptional_silence_phone = nonoptional_silence_phone - self.oov_phone = oov_phone - self.other_noise_phone = other_noise_phone - self.debug = debug - self.non_silence_phones: Set[str] = set() - self.max_disambiguation_symbol = 0 - self.disambiguation_symbols = set() - self.clitic_set: Set[str] = set() - - @property - def silence_phones(self): - return { - self.oov_phone, - self.optional_silence_phone, - self.nonoptional_silence_phone, - self.other_noise_phone, - } - - @property - def specials_set(self): - return {self.oov_word, self.silence_word, "", "", ""} - - def update(self, data: dict) -> None: - for k, v in data.items(): - if not hasattr(self, k): - continue - if k == "phones": - continue - if k in ["punctuation", "clitic_markers", "compound_markers"]: - if not v: - continue - if "-" in v: - v = "-" + v.replace("-", "") - if "]" in v and r"\]" not in v: - v = v.replace("]", r"\]") - print(k, v) - setattr(self, k, v) - - @property - def phone_mapping(self) -> Dict[str, int]: - phone_mapping = {} - i = 0 - phone_mapping[""] = i - if self.position_dependent_phones: - for p in self.positional_silence_phones: - i += 1 - phone_mapping[p] = i - for p in self.positional_non_silence_phones: - i += 1 - phone_mapping[p] = i - else: - for p in sorted(self.silence_phones): - i += 1 - phone_mapping[p] = i - for p in sorted(self.non_silence_phones): - i += 1 - phone_mapping[p] = i - i = max(phone_mapping.values()) - for x in range(self.max_disambiguation_symbol + 2): - p = f"#{x}" - self.disambiguation_symbols.add(p) - i += 1 - phone_mapping[p] = i - return phone_mapping - - @property - def positional_silence_phones(self) -> List[str]: - """ - List of silence phones with positions - """ - 
silence_phones = [] - for p in sorted(self.silence_phones): - silence_phones.append(p) - for pos in self.positions: - silence_phones.append(p + pos) - return silence_phones - - @property - def positional_non_silence_phones(self) -> List[str]: - """ - List of non-silence phones with positions - """ - non_silence_phones = [] - for p in sorted(self.non_silence_phones): - for pos in self.positions: - non_silence_phones.append(p + pos) - return non_silence_phones - - @property - def kaldi_silence_phones(self): - if self.position_dependent_phones: - return self.positional_silence_phones - return sorted(self.silence_phones) - - @property - def kaldi_non_silence_phones(self): - if self.position_dependent_phones: - return self.positional_non_silence_phones - return sorted(self.non_silence_phones) - - @property - def optional_silence_csl(self) -> str: - """ - Phone id of the optional silence phone - """ - return str(self.phone_mapping[self.optional_silence_phone]) - - @property - def silence_csl(self) -> str: - """ - A colon-separated list (as a string) of silence phone ids - """ - return ":".join(map(str, (self.phone_mapping[x] for x in self.kaldi_silence_phones))) - - @property - def phones(self) -> set: - """ - The set of all phones (silence and non-silence) - """ - return self.silence_phones | self.non_silence_phones - - def check_bracketed(self, word: str) -> bool: - """ - Checks whether a given string is surrounded by brackets. 
- - Parameters - ---------- - word : str - Text to check for final brackets - - Returns - ------- - bool - True if the word is fully bracketed, false otherwise - """ - for b in self.brackets: - if word.startswith(b[0]) and word.endswith(b[-1]): - return True - return False - - def sanitize(self, item: str) -> str: - """ - Sanitize an item according to punctuation and clitic markers - - Parameters - ---------- - item: str - Word to sanitize - - Returns - ------- - str - Sanitized form - """ - for c in self.clitic_markers: - item = item.replace(c, self.clitic_markers[0]) - if not item: - return item - if self.check_bracketed(item): - return item - sanitized = re.sub(rf"^[{''.join(self.punctuation)}]+", "", item) - sanitized = re.sub(rf"[{''.join(self.punctuation)}]+$", "", sanitized) - - return sanitized - - def parse_ipa(self, transcription: List[str]) -> Tuple[str, ...]: - """ - Parse a transcription in a multilingual IPA format (strips out diacritics and splits digraphs). - - Parameters - ---------- - transcription: List[str] - Transcription to parse - - Returns - ------- - Tuple[str, ...] 
- Parsed transcription - """ - new_transcription = [] - for t in transcription: - new_t = t - for d in self.strip_diacritics: - new_t = new_t.replace(d, "") - if "g" in new_t: - new_t = new_t.replace("g", "ɡ") - - found = False - for digraph in self.digraphs: - if re.match(rf"^{digraph}$", new_t): - found = True - if found: - new_transcription.extend(new_t) - continue - new_transcription.append(new_t) - return tuple(new_transcription) diff --git a/montreal_forced_aligner/config/feature_config.py b/montreal_forced_aligner/config/feature_config.py deleted file mode 100644 index db74496a..00000000 --- a/montreal_forced_aligner/config/feature_config.py +++ /dev/null @@ -1,147 +0,0 @@ -"""Classes for configuring feature generation""" -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Dict, Text, Union - -from ..exceptions import ConfigError -from .base_config import BaseConfig - -if TYPE_CHECKING: - SpeakerCharacterType = Union[str, int] - -__all__ = ["make_safe", "FeatureConfig"] - - -def make_safe(value: Any) -> str: - """ - Transform an arbitrary value into a string - - Parameters - ---------- - value: Any - Value to make safe - - Returns - ------- - str - Safe value - """ - if isinstance(value, bool): - return str(value).lower() - return str(value) - - -class FeatureConfig(BaseConfig): - """ - Class to store configuration information about MFCC generation - - Attributes - ---------- - directory : str - Path of the directory to store outputs - type : str - Feature type, defaults to "mfcc" - deltas : bool - Flag for whether deltas from previous frames are included in the features, defaults to True - lda : bool - Flag for whether LDA is run on the features, requires an lda.mat to generate, defaults to False - fmllr : bool - Flag for whether speaker adaptation should be run, defaults to False - use_energy : bool - Flag for whether first coefficient should be used, defaults to False - frame_shift : int - number of milliseconds between frames, 
defaults to 10 - snip_edges : bool - Flag for enabling Kaldi's snip edges, should be better time precision - pitch : bool - Flag for including pitch in features, currently nonfunctional, defaults to False - low_frequency : int - Frequency floor - high_frequency : int - Frequency ceiling - sample_frequency : int - Sampling frequency - allow_downsample : bool - Flag for whether to allow downsampling, default is True - allow_upsample : bool - Flag for whether to allow upsampling, default is True - splice_left_context : int or None - Number of frames to splice on the left for calculating LDA - splice_right_context : int or None - Number of frames to splice on the right for calculating LDA - use_mp : bool - Flag for using multiprocessing, defaults to True - """ - - deprecated_flags = {"lda", "deltas"} - - def __init__(self): - self.type = "mfcc" - self.deltas = True - self.fmllr = False - self.lda = False - self.use_energy = False - self.frame_shift = 10 - self.snip_edges = True - self.pitch = False - self.low_frequency = 20 - self.high_frequency = 7800 - self.sample_frequency = 16000 - self.allow_downsample = True - self.allow_upsample = True - self.splice_left_context = 3 - self.splice_right_context = 3 - self.use_mp = True - - def params(self) -> Dict[Text, Any]: - """Parameters for feature generation""" - return { - "type": self.type, - "use_energy": self.use_energy, - "frame_shift": self.frame_shift, - "snip_edges": self.snip_edges, - "low_frequency": self.low_frequency, - "high_frequency": self.high_frequency, - "sample_frequency": self.sample_frequency, - "allow_downsample": self.allow_downsample, - "allow_upsample": self.allow_upsample, - "pitch": self.pitch, - "fmllr": self.fmllr, - "splice_left_context": self.splice_left_context, - "splice_right_context": self.splice_right_context, - } - - @property - def mfcc_options(self) -> Dict[Text, Any]: - """Parameters to use in computing MFCC features.""" - return { - "use-energy": self.use_energy, - "frame-shift": 
self.frame_shift, - "low-freq": self.low_frequency, - "high-freq": self.high_frequency, - "sample-frequency": self.sample_frequency, - "allow-downsample": self.allow_downsample, - "allow-upsample": self.allow_upsample, - "snip-edges": self.snip_edges, - } - - def update(self, data: Dict[str, Any]) -> None: - """ - Update configuration with new data - - Parameters - ---------- - data: Dict[str, Any] - New data - """ - for k, v in data.items(): - if k in self.deprecated_flags: - continue - if not hasattr(self, k): - raise ConfigError("No field found for key {}".format(k)) - setattr(self, k, v) - - @property - def feature_id(self) -> str: - """Deprecated feature ID""" - return "feats" diff --git a/montreal_forced_aligner/config/g2p_config.py b/montreal_forced_aligner/config/g2p_config.py deleted file mode 100644 index fe4d95ca..00000000 --- a/montreal_forced_aligner/config/g2p_config.py +++ /dev/null @@ -1,78 +0,0 @@ -"""Class definitions for configuring G2P generation""" -from __future__ import annotations - -from typing import Tuple - -import yaml - -from .base_config import BaseConfig -from .dictionary_config import DictionaryConfig - -__all__ = ["G2PConfig", "g2p_yaml_to_config", "load_basic_g2p_config"] - - -class G2PConfig(BaseConfig): - """ - Configuration class for generating pronunciations - - """ - - def __init__(self): - self.num_pronunciations = 1 - self.use_mp = True - - def update(self, data: dict) -> None: - """Update configuration""" - for k, v in data.items(): - if k in ["punctuation", "clitic_markers", "compound_markers"]: - if not v: - continue - if "-" in v: - v = "-" + v.replace("-", "") - if "]" in v and r"\]" not in v: - v = v.replace("]", r"\]") - elif not hasattr(self, k): - continue - setattr(self, k, v) - - -def g2p_yaml_to_config(path: str) -> Tuple[G2PConfig, DictionaryConfig]: - """ - Helper function to load G2P configurations - - Parameters - ---------- - path: str - Path to yaml file - - Returns - ------- - 
:class:`~montreal_forced_aligner.config.G2PConfig` - G2P configuration - :class:`~montreal_forced_aligner.config.DictionaryConfig` - Dictionary configuration - """ - dictionary_config = DictionaryConfig() - with open(path, "r", encoding="utf8") as f: - data = yaml.load(f, Loader=yaml.SafeLoader) - global_params = {} - for k, v in data.items(): - global_params[k] = v - g2p_config = G2PConfig() - g2p_config.update(global_params) - dictionary_config.update(global_params) - return g2p_config, dictionary_config - - -def load_basic_g2p_config() -> Tuple[G2PConfig, DictionaryConfig]: - """ - Helper function to load the default parameters - - Returns - ------- - :class:`~montreal_forced_aligner.config.G2PConfig` - Default G2P configuration - :class:`~montreal_forced_aligner.config.DictionaryConfig` - Dictionary configuration - """ - return G2PConfig(), DictionaryConfig() diff --git a/montreal_forced_aligner/config/segmentation_config.py b/montreal_forced_aligner/config/segmentation_config.py deleted file mode 100644 index 698abc4a..00000000 --- a/montreal_forced_aligner/config/segmentation_config.py +++ /dev/null @@ -1,91 +0,0 @@ -"""Class definitions for configuring sound file segmentation""" -from __future__ import annotations - -import os - -import yaml - -from .base_config import BaseConfig -from .feature_config import FeatureConfig - -__all__ = ["SegmentationConfig", "segmentation_yaml_to_config", "load_basic_segmentation"] - - -class SegmentationConfig(BaseConfig): - """ - Class for storing segmentation configuration - """ - - def __init__(self, feature_config): - self.use_mp = True - self.energy_threshold = 5.5 - self.energy_mean_scale = 0.5 - self.max_segment_length = 30 - self.min_pause_duration = 0.05 - self.snap_boundary_threshold = 0.15 - self.feature_config = feature_config - self.feature_config.use_energy = True - self.overwrite = True - - def update(self, data: dict) -> None: - """Update configuration parameters""" - for k, v in data.items(): - if k == 
"use_mp": - self.feature_config.use_mp = v - if not hasattr(self, k): - continue - setattr(self, k, v) - - @property - def segmentation_options(self): - """Options for segmentation""" - return { - "max_segment_length": self.max_segment_length, - "min_pause_duration": self.min_pause_duration, - "snap_boundary_threshold": self.snap_boundary_threshold, - "frame_shift": round(self.feature_config.frame_shift / 1000, 2), - } - - -def segmentation_yaml_to_config(path: str) -> SegmentationConfig: - """ - Helper function to load segmentation configurations - - Parameters - ---------- - path: str - Path to yaml file - - Returns - ------- - :class:`~montreal_forced_aligner.config.segmentation_config.SegmentationConfig` - Segmentation configuration - """ - with open(path, "r", encoding="utf8") as f: - data = yaml.load(f, Loader=yaml.SafeLoader) - global_params = {} - feature_config = FeatureConfig() - for k, v in data.items(): - if k == "features": - feature_config.update(v) - else: - global_params[k] = v - segmentation_config = SegmentationConfig(feature_config) - segmentation_config.update(global_params) - return segmentation_config - - -def load_basic_segmentation() -> SegmentationConfig: - """ - Helper function to load the default parameters - - Returns - ------- - :class:`~montreal_forced_aligner.config.segmentation_config.SegmentationConfig` - Default segmentation configuration - """ - base_dir = os.path.dirname(os.path.abspath(__file__)) - segmentation_config = segmentation_yaml_to_config( - os.path.join(base_dir, "basic_segmentation.yaml") - ) - return segmentation_config diff --git a/montreal_forced_aligner/config/speaker_classification_config.py b/montreal_forced_aligner/config/speaker_classification_config.py deleted file mode 100644 index 4014734c..00000000 --- a/montreal_forced_aligner/config/speaker_classification_config.py +++ /dev/null @@ -1,69 +0,0 @@ -"""Class definitions for configuring speaker classification""" -from __future__ import annotations - -import 
os - -import yaml - -from .base_config import BaseConfig - -__all__ = [ - "SpeakerClassificationConfig", - "classification_yaml_to_config", - "load_basic_classification", -] - - -class SpeakerClassificationConfig(BaseConfig): - """ - Configuration class to store parameters for speaker classification - """ - - def __init__(self): - self.use_mp = True - self.pca_dimension = -1 - self.target_energy = 0.1 - self.cluster_threshold = 0.5 - self.max_speaker_fraction = 1.0 - self.first_pass_max_utterances = 32767 - self.rttm_channel = 0 - self.read_costs = False - self.overwrite = False - - -def classification_yaml_to_config(path: str) -> SpeakerClassificationConfig: - """ - Helper function to load speaker classification configurations - - Parameters - ---------- - path: str - Path to yaml file - - Returns - ------- - :class:`~montreal_forced_aligner.config.SpeakerClassificationConfig` - Speaker classification configuration - """ - with open(path, "r", encoding="utf8") as f: - data = yaml.load(f, Loader=yaml.SafeLoader) - classification_config = SpeakerClassificationConfig() - if data: - classification_config.update(data) - return classification_config - - -def load_basic_classification() -> SpeakerClassificationConfig: - """ - Helper function to load the default parameters - - Returns - ------- - :class:`~montreal_forced_aligner.config.SpeakerClassificationConfig` - Default speaker classification configuration - """ - base_dir = os.path.dirname(os.path.abspath(__file__)) - classification_config = classification_yaml_to_config( - os.path.join(base_dir, "basic_classification.yaml") - ) - return classification_config diff --git a/montreal_forced_aligner/config/test_config.yaml b/montreal_forced_aligner/config/test_config.yaml deleted file mode 100644 index e22e9ae0..00000000 --- a/montreal_forced_aligner/config/test_config.yaml +++ /dev/null @@ -1,14 +0,0 @@ -beam: 10 -retry_beam: 40 - -features: - type: "mfcc" - use_energy: false - frame_shift: 10 - pitch: false - 
-training: - - monophone: - num_iterations: 40 - max_gaussians: 1000 - subset: 10000 diff --git a/montreal_forced_aligner/config/train_config.py b/montreal_forced_aligner/config/train_config.py deleted file mode 100644 index 9af857f1..00000000 --- a/montreal_forced_aligner/config/train_config.py +++ /dev/null @@ -1,267 +0,0 @@ -"""Class definitions for configuring acoustic model training""" -from __future__ import annotations - -import os -from collections import Counter -from typing import Iterator, List, Tuple - -import yaml - -from ..exceptions import ConfigError -from ..trainers import ( - BaseTrainer, - IvectorExtractorTrainer, - LdaTrainer, - MonophoneTrainer, - SatTrainer, - TriphoneTrainer, -) -from .align_config import AlignConfig -from .base_config import BaseConfig -from .dictionary_config import DictionaryConfig -from .feature_config import FeatureConfig - -__all__ = [ - "TrainingConfig", - "train_yaml_to_config", - "load_basic_train", - "load_basic_train_ivector", - "load_test_config", - "load_sat_adapt", - "load_no_sat_adapt", -] - - -class TrainingConfig(BaseConfig): - """ - Configuration class for storing parameters and trainers for training acoustic models - """ - - def __init__(self, training_configs): - self.training_configs = training_configs - counts = Counter([x.train_type for x in self.training_configs]) - self.training_identifiers = [] - curs = {x.train_type: 1 for x in self.training_configs} - for t in training_configs: - i = t.train_type - if counts[t.train_type] != 1: - i += str(curs[t.train_type]) - curs[t.train_type] += 1 - self.training_identifiers.append(i) - - def update_from_align(self, align_config: AlignConfig) -> None: - """Update parameters from an AlignConfig""" - for tc in self.training_configs: - tc.overwrite = align_config.overwrite - tc.cleanup_textgrids = align_config.cleanup_textgrids - - def update(self, data: dict) -> None: - """Update parameters""" - for k, v in data.items(): - if not hasattr(self, k): - continue - 
setattr(self, k, v) - for trainer in self.values(): - trainer.update(data) - - def keys(self) -> List: - """List of training identifiers""" - return self.training_identifiers - - def values(self) -> List[BaseTrainer]: - """List of trainers""" - return self.training_configs - - def items(self) -> Iterator: - """Iterator over training identifiers and trainers""" - return zip(self.training_identifiers, self.training_configs) - - def __getitem__(self, item: str) -> BaseTrainer: - """Get trainer based on identifier""" - if item not in self.training_identifiers: - raise KeyError(f"{item} not a valid training identifier") - return self.training_configs[self.training_identifiers.index(item)] - - @property - def uses_sat(self) -> bool: - """Flag for whether a trainer uses speaker adaptation""" - for k in self.keys(): - if k.startswith("sat"): - return True - return False - - -def train_yaml_to_config( - path: str, require_mono: bool = True -) -> Tuple[TrainingConfig, AlignConfig, DictionaryConfig]: - """ - Helper function to load acoustic model training configurations - - Parameters - ---------- - path: str - Path to yaml file - - Returns - ------- - :class:`~montreal_forced_aligner.config.TrainingConfig` - Training configuration - :class:`~montreal_forced_aligner.config.AlignConfig` - Alignment configuration - :class:`~montreal_forced_aligner.config.DictionaryConfig` - Dictionary configuration - """ - dictionary_config = DictionaryConfig() - with open(path, "r", encoding="utf8") as f: - data = yaml.load(f, Loader=yaml.SafeLoader) - global_params = {} - training = [] - training_params = [] - global_feature_params = {} - for k, v in data.items(): - if k == "training": - for t in v: - for k2, v2 in t.items(): - feature_config = FeatureConfig() - if k2 == "monophone": - training.append(MonophoneTrainer(feature_config)) - elif k2 == "triphone": - training.append(TriphoneTrainer(feature_config)) - elif k2 == "lda": - training.append(LdaTrainer(feature_config)) - elif k2 == 
"sat": - training.append(SatTrainer(feature_config)) - elif k2 == "ivector": - training.append(IvectorExtractorTrainer(feature_config)) - training_params.append(v2) - elif k == "features": - global_feature_params.update(v) - else: - global_params[k] = v - feature_config = FeatureConfig() - feature_config.update(global_feature_params) - align_config = AlignConfig(feature_config) - align_config.update(global_params) - dictionary_config.update(global_params) - training_config = None - if training: - for i, t in enumerate(training): - if i == 0 and require_mono and t.train_type not in ["mono", "ivector"]: - raise ConfigError("The first round of training must be monophone.") - t.update(global_params) - t.update(training_params[i]) - t.feature_config.update(global_feature_params) - training_config = TrainingConfig(training) - align_config.feature_config.fmllr = training_config.uses_sat - if align_config.beam >= align_config.retry_beam: - raise ConfigError("Retry beam must be greater than beam.") - return training_config, align_config, dictionary_config - - -def load_basic_train() -> Tuple[TrainingConfig, AlignConfig, DictionaryConfig]: - """ - Helper function to load the default parameters - - Returns - ------- - :class:`~montreal_forced_aligner.config.TrainingConfig` - Training configuration - :class:`~montreal_forced_aligner.config.AlignConfig` - Alignment configuration - :class:`~montreal_forced_aligner.config.DictionaryConfig` - Dictionary configuration - """ - base_dir = os.path.dirname(os.path.abspath(__file__)) - training_config, align_config, dictionary_config = train_yaml_to_config( - os.path.join(base_dir, "basic_train.yaml") - ) - return training_config, align_config, dictionary_config - - -def load_sat_adapt() -> Tuple[TrainingConfig, AlignConfig, DictionaryConfig]: - """ - Helper function to load the default speaker adaptation parameters for adapting an acoustic model to new data - - Returns - ------- - :class:`~montreal_forced_aligner.config.TrainingConfig` 
- Training configuration - :class:`~montreal_forced_aligner.config.AlignConfig` - Alignment configuration - :class:`~montreal_forced_aligner.config.DictionaryConfig` - Dictionary configuration - """ - base_dir = os.path.dirname(os.path.abspath(__file__)) - training_config, align_config, dictionary_config = train_yaml_to_config( - os.path.join(base_dir, "adapt_sat.yaml"), require_mono=False - ) - training_config.training_configs[0].fmllr_iterations = range( - 0, training_config.training_configs[0].num_iterations - ) - training_config.training_configs[0].realignment_iterations = range( - 0, training_config.training_configs[0].num_iterations - ) - return training_config, align_config, dictionary_config - - -def load_no_sat_adapt() -> Tuple[TrainingConfig, AlignConfig, DictionaryConfig]: - """ - Helper function to load the default parameters for adapting an acoustic model to new data without speaker adaptation - - Returns - ------- - :class:`~montreal_forced_aligner.config.TrainingConfig` - Training configuration - :class:`~montreal_forced_aligner.config.AlignConfig` - Alignment configuration - :class:`~montreal_forced_aligner.config.DictionaryConfig` - Dictionary configuration - """ - base_dir = os.path.dirname(os.path.abspath(__file__)) - training_config, align_config, dictionary_config = train_yaml_to_config( - os.path.join(base_dir, "adapt_nosat.yaml"), require_mono=False - ) - training_config.training_configs[0].realignment_iterations = range( - 0, training_config.training_configs[0].num_iterations - ) - return training_config, align_config, dictionary_config - - -def load_basic_train_ivector() -> Tuple[TrainingConfig, AlignConfig, DictionaryConfig]: - """ - Helper function to load the default parameters for training ivector extractors - - Returns - ------- - :class:`~montreal_forced_aligner.config.TrainingConfig` - Training configuration - :class:`~montreal_forced_aligner.config.AlignConfig` - Alignment configuration - 
:class:`~montreal_forced_aligner.config.DictionaryConfig` - Dictionary configuration - """ - base_dir = os.path.dirname(os.path.abspath(__file__)) - training_config, align_config, dictionary_config = train_yaml_to_config( - os.path.join(base_dir, "basic_train_ivector.yaml") - ) - return training_config, align_config, dictionary_config - - -def load_test_config() -> Tuple[TrainingConfig, AlignConfig, DictionaryConfig]: - """ - Helper function to load the default parameters for validating corpora - - Returns - ------- - :class:`~montreal_forced_aligner.config.TrainingConfig` - Training configuration - :class:`~montreal_forced_aligner.config.AlignConfig` - Alignment configuration - :class:`~montreal_forced_aligner.config.DictionaryConfig` - Dictionary configuration - """ - base_dir = os.path.dirname(os.path.abspath(__file__)) - training_config, align_config, dictionary_config = train_yaml_to_config( - os.path.join(base_dir, "test_config.yaml") - ) - return training_config, align_config, dictionary_config diff --git a/montreal_forced_aligner/config/train_g2p_config.py b/montreal_forced_aligner/config/train_g2p_config.py deleted file mode 100644 index 0e8a366e..00000000 --- a/montreal_forced_aligner/config/train_g2p_config.py +++ /dev/null @@ -1,73 +0,0 @@ -"""Class definitions for configuring G2P model training""" -from __future__ import annotations - -from typing import Tuple - -import yaml - -from .base_config import BaseConfig -from .dictionary_config import DictionaryConfig - -__all__ = ["TrainG2PConfig", "train_g2p_yaml_to_config", "load_basic_train_g2p_config"] - - -class TrainG2PConfig(BaseConfig): - """ - Configuration class for training G2P models - """ - - def __init__(self): - self.num_pronunciations = 1 - self.order = 7 - self.random_starts = 25 - self.seed = 1917 - self.delta = 1 / 1024 - self.lr = 1.0 - self.batch_size = 200 - self.max_iterations = 10 - self.smoothing_method = "kneser_ney" - self.pruning_method = "relative_entropy" - self.model_size = 
1000000 - self.use_mp = True - - -def train_g2p_yaml_to_config(path: str) -> Tuple[TrainG2PConfig, DictionaryConfig]: - """ - Helper function to load G2P training configurations - - Parameters - ---------- - path: str - Path to yaml file - - Returns - ------- - :class:`~montreal_forced_aligner.config.TrainG2PConfig` - G2P training configuration - :class:`~montreal_forced_aligner.config.DictionaryConfig` - Dictionary configuration - """ - dictionary_config = DictionaryConfig() - with open(path, "r", encoding="utf8") as f: - data = yaml.load(f, Loader=yaml.SafeLoader) - global_params = {} - for k, v in data.items(): - global_params[k] = v - g2p_config = TrainG2PConfig() - g2p_config.update(global_params) - dictionary_config.update(global_params) - return g2p_config, dictionary_config - - -def load_basic_train_g2p_config() -> Tuple[TrainG2PConfig, DictionaryConfig]: - """ - Helper function to load the default parameters - - Returns - ------- - :class:`~montreal_forced_aligner.config.TrainG2PConfig` - Default G2P training configuration - :class:`~montreal_forced_aligner.config.DictionaryConfig` - Dictionary configuration - """ - return TrainG2PConfig(), DictionaryConfig() diff --git a/montreal_forced_aligner/config/train_lm_config.py b/montreal_forced_aligner/config/train_lm_config.py deleted file mode 100644 index 802b706a..00000000 --- a/montreal_forced_aligner/config/train_lm_config.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Class definitions for configuring language model training""" -from __future__ import annotations - -import os -from typing import Tuple - -import yaml - -from .base_config import BaseConfig -from .dictionary_config import DictionaryConfig - -__all__ = ["TrainLMConfig", "train_lm_yaml_to_config", "load_basic_train_lm"] - - -class TrainLMConfig(BaseConfig): - """ - Class for storing configuration information for training language models - - Attributes - ---------- - order: int - method: str - prune: bool - count_threshold: int - prune_thresh_small: 
float - prune_thresh_medium: float - use_mp: bool - """ - - def __init__(self): - self.order = 3 - self.method = "kneser_ney" - self.prune = False - self.count_threshold = 1 - self.prune_thresh_small = 0.0000003 - self.prune_thresh_medium = 0.0000001 - self.use_mp = True - - -def train_lm_yaml_to_config(path: str) -> Tuple[TrainLMConfig, DictionaryConfig]: - """ - Helper function to load language model training configurations - - Parameters - ---------- - path: str - Path to yaml file - - Returns - ------- - :class:`~montreal_forced_aligner.config.TrainLMConfig` - Language model training configuration - :class:`~montreal_forced_aligner.config.DictionaryConfig` - Dictionary configuration - """ - dictionary_config = DictionaryConfig() - with open(path, "r", encoding="utf8") as f: - data = yaml.load(f, Loader=yaml.SafeLoader) - config = TrainLMConfig() - config.update(data) - dictionary_config.update(data) - return config, dictionary_config - - -def load_basic_train_lm() -> Tuple[TrainLMConfig, DictionaryConfig]: - """ - Helper function to load the default parameters - - Returns - ------- - :class:`~montreal_forced_aligner.config.TrainLMConfig` - Default language model training configuration - :class:`~montreal_forced_aligner.config.DictionaryConfig` - Dictionary configuration - """ - base_dir = os.path.dirname(os.path.abspath(__file__)) - training_config, dictionary_config = train_lm_yaml_to_config( - os.path.join(base_dir, "basic_train_lm.yaml") - ) - return training_config, dictionary_config diff --git a/montreal_forced_aligner/config/transcribe_config.py b/montreal_forced_aligner/config/transcribe_config.py deleted file mode 100644 index f41ae700..00000000 --- a/montreal_forced_aligner/config/transcribe_config.py +++ /dev/null @@ -1,178 +0,0 @@ -"""Class definitions for configuring transcription""" -from __future__ import annotations - -import os -from typing import TYPE_CHECKING, Tuple - -import yaml - -from .base_config import BaseConfig -from .dictionary_config 
import DictionaryConfig -from .feature_config import FeatureConfig - -if TYPE_CHECKING: - from ..abc import MetaDict - -__all__ = ["TranscribeConfig", "transcribe_yaml_to_config", "load_basic_transcribe"] - - -class TranscribeConfig(BaseConfig): - """ - Class for storing metadata to configure transcription - - Parameters - ---------- - feature_config: :class:`~montreal_forced_aligner.config.FeatureConfig` - Feature configuration to use in transcription - - Attributes - ---------- - transition_scale: float - acoustic_scale: float - self_loop_scale: float - silence_weight: float - beam: int - max_active: int - fmllr: bool - fmllr_update_type: str - lattice_beam: int - first_beam: int, optional - """ - - def __init__(self, feature_config: FeatureConfig): - self.transition_scale = 1.0 - self.acoustic_scale = 0.083333 - self.self_loop_scale = 0.1 - self.feature_config = feature_config - self.silence_weight = 0.01 - self.beam = 10 - self.max_active = 7000 - self.fmllr = True - self.fmllr_update_type = "full" - self.lattice_beam = 6 - self.first_beam = None - self.first_max_active = 2000 - self.language_model_weight = 10 - self.word_insertion_penalty = 0.5 - self.data_directory = None # Gets set later - self.use_mp = True - self.use_fmllr_mp = False - self.ignore_speakers = False - self.overwrite = False - - def params(self) -> MetaDict: - """Metadata parameters for the configuration""" - return { - "transition_scale": self.transition_scale, - "acoustic_scale": self.acoustic_scale, - "self_loop_scale": self.self_loop_scale, - "silence_weight": self.silence_weight, - "beam": self.beam, - "max_active": self.max_active, - "fmllr": self.fmllr, - "fmllr_update_type": self.fmllr_update_type, - "lattice_beam": self.lattice_beam, - "first_beam": self.first_beam, - "first_max_active": self.first_max_active, - "language_model_weight": self.language_model_weight, - "word_insertion_penalty": self.word_insertion_penalty, - "use_mp": self.use_mp, - } - - @property - def 
decode_options(self) -> MetaDict: - """Options needed for decoding""" - return { - "fmllr": self.fmllr, - "ignore_speakers": self.ignore_speakers, - "first_beam": self.first_beam, - "beam": self.beam, - "first_max_active": self.first_max_active, - "max_active": self.max_active, - "lattice_beam": self.lattice_beam, - "acoustic_scale": self.acoustic_scale, - } - - @property - def score_options(self) -> MetaDict: - """Options needed for scoring lattices""" - return { - "language_model_weight": self.language_model_weight, - "word_insertion_penalty": self.word_insertion_penalty, - } - - @property - def fmllr_options(self) -> MetaDict: - """Options needed for calculating fMLLR transformations""" - return { - "fmllr_update_type": self.fmllr_update_type, - "acoustic_scale": self.acoustic_scale, - "silence_weight": self.silence_weight, - "lattice_beam": self.lattice_beam, - } - - @property - def lm_rescore_options(self) -> MetaDict: - """Options needed for rescoring the language model""" - return { - "acoustic_scale": self.acoustic_scale, - } - - def update(self, data: dict) -> None: - """Update configuration with new parameters""" - for k, v in data.items(): - if k == "use_mp": - self.feature_config.use_mp = v - if not hasattr(self, k): - continue - setattr(self, k, v) - - -def transcribe_yaml_to_config(path: str) -> Tuple[TranscribeConfig, DictionaryConfig]: - """ - Helper function to load transcription configurations - - Parameters - ---------- - path: str - Path to yaml file - - Returns - ------- - :class:`~montreal_forced_aligner.config.TranscribeConfig` - Transcription configuration - :class:`~montreal_forced_aligner.config.DictionaryConfig` - Dictionary configuration - """ - dictionary_config = DictionaryConfig() - with open(path, "r", encoding="utf8") as f: - data = yaml.load(f, Loader=yaml.SafeLoader) - global_params = {} - feature_config = FeatureConfig() - for k, v in data.items(): - if k == "features": - feature_config.update(v) - else: - global_params[k] = v - 
config = TranscribeConfig(feature_config) - config.update(global_params) - dictionary_config.update(global_params) - return config, dictionary_config - - -def load_basic_transcribe() -> Tuple[TranscribeConfig, DictionaryConfig]: - """ - Helper function to load the default parameters - - Returns - ------- - :class:`~montreal_forced_aligner.config.TranscribeConfig` - Default transcription configuration - :class:`~montreal_forced_aligner.config.DictionaryConfig` - Dictionary configuration - """ - base_dir = os.path.dirname(os.path.abspath(__file__)) - config, dictionary_config = transcribe_yaml_to_config( - os.path.join(base_dir, "basic_transcribe.yaml") - ) - return config, dictionary_config diff --git a/montreal_forced_aligner/corpus/__init__.py b/montreal_forced_aligner/corpus/__init__.py index 50cb731c..0c28a634 100644 --- a/montreal_forced_aligner/corpus/__init__.py +++ b/montreal_forced_aligner/corpus/__init__.py @@ -6,12 +6,36 @@ """ from __future__ import annotations -from .base import Corpus # noqa -from .classes import File, Speaker, Utterance +from montreal_forced_aligner.corpus.acoustic_corpus import ( + AcousticCorpus, + AcousticCorpusMixin, + AcousticCorpusPronunciationMixin, +) +from montreal_forced_aligner.corpus.base import CorpusMixin +from montreal_forced_aligner.corpus.classes import File, Speaker, Utterance +from montreal_forced_aligner.corpus.text_corpus import ( + DictionaryTextCorpusMixin, + TextCorpus, + TextCorpusMixin, +) -__all__ = ["Corpus", "Speaker", "Utterance", "File", "base", "helper", "classes"] - -Corpus.__module__ = "montreal_forced_aligner.corpus" -Speaker.__module__ = "montreal_forced_aligner.corpus" -Utterance.__module__ = "montreal_forced_aligner.corpus" -File.__module__ = "montreal_forced_aligner.corpus" +__all__ = [ + "base", + "helper", + "classes", + "File", + "Speaker", + "Utterance", + "features", + "multiprocessing", + "CorpusMixin", + "ivector_corpus", + "acoustic_corpus", + "AcousticCorpus", + "AcousticCorpusMixin", + 
"AcousticCorpusPronunciationMixin", + "text_corpus", + "TextCorpus", + "TextCorpusMixin", + "DictionaryTextCorpusMixin", +] diff --git a/montreal_forced_aligner/corpus/acoustic_corpus.py b/montreal_forced_aligner/corpus/acoustic_corpus.py new file mode 100644 index 00000000..e86f1275 --- /dev/null +++ b/montreal_forced_aligner/corpus/acoustic_corpus.py @@ -0,0 +1,893 @@ +"""Class definitions for corpora""" +from __future__ import annotations + +import multiprocessing as mp +import os +import shutil +import subprocess +import sys +import time +from abc import ABCMeta +from queue import Empty +from typing import Optional + +from montreal_forced_aligner.abc import MfaWorker, TemporaryDirectoryMixin +from montreal_forced_aligner.corpus.base import CorpusMixin +from montreal_forced_aligner.corpus.classes import parse_file +from montreal_forced_aligner.corpus.features import ( + CalcFmllrArguments, + FeatureConfigMixin, + MfccArguments, + VadArguments, + calc_fmllr_func, + compute_vad_func, + mfcc_func, +) +from montreal_forced_aligner.corpus.helper import find_exts +from montreal_forced_aligner.corpus.multiprocessing import CorpusProcessWorker +from montreal_forced_aligner.dictionary.multispeaker import MultispeakerDictionaryMixin +from montreal_forced_aligner.exceptions import TextGridParseError, TextParseError +from montreal_forced_aligner.helper import load_scp +from montreal_forced_aligner.utils import Stopped, run_mp, run_non_mp, thirdparty_binary + + +class AcousticCorpusMixin(CorpusMixin, FeatureConfigMixin, metaclass=ABCMeta): + """ + Mixin class for acoustic corpora + + Parameters + ---------- + audio_directory: str + Extra directory to look for audio files + + See Also + -------- + :class:`~montreal_forced_aligner.corpus.base.CorpusMixin` + For corpus parsing parameters + :class:`~montreal_forced_aligner.corpus.features.FeatureConfigMixin` + For feature generation parameters + + Attributes + ---------- + sound_file_errors: list[str] + List of sound files with 
errors in loading + transcriptions_without_wavs: list[str] + List of text files without sound files + no_transcription_files: list[str] + List of sound files without transcription files + stopped: Stopped + Stop check for loading the corpus + """ + + def __init__(self, audio_directory: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + self.audio_directory = audio_directory + self.sound_file_errors = [] + self.transcriptions_without_wavs = [] + self.no_transcription_files = [] + self.stopped = Stopped() + + def load_corpus(self) -> None: + """ + Load the corpus + """ + self._load_corpus() + + self.initialize_jobs() + self.write_corpus_information() + self.create_corpus_split() + self.generate_features() + + def generate_features(self, overwrite: bool = False, compute_cmvn: bool = True) -> None: + """ + Generate features for the corpus + + Parameters + ---------- + overwrite: bool + Flag for whether to ignore existing files, defaults to False + compute_cmvn: bool + Flag for whether to compute CMVN, defaults to True + """ + if not overwrite and os.path.exists( + os.path.join(self.corpus_output_directory, "feats.scp") + ): + return + self.log_info(f"Generating base features ({self.feature_type})...") + if self.feature_type == "mfcc": + self.mfcc() + self.combine_feats() + if compute_cmvn: + self.log_info("Calculating CMVN...") + self.calc_cmvn() + self.write_corpus_information() + self.create_corpus_split() + + def write_corpus_information(self) -> None: + """ + Output information to the temporary directory for later loading + """ + super().write_corpus_information() + self._write_feats() + + def construct_base_feature_string(self, all_feats: bool = False) -> str: + """ + Construct the base feature string independent of job name + + Used in initialization of MonophoneTrainer (to get dimension size) and IvectorTrainer (uses all feats) + + Parameters + ---------- + all_feats: bool + Flag for whether all features across all jobs should be taken into account 
+ + Returns + ------- + str + Base feature string + """ + j = self.jobs[0] + if all_feats: + feat_path = os.path.join(self.base_data_directory, "feats.scp") + utt2spk_path = os.path.join(self.base_data_directory, "utt2spk.scp") + cmvn_path = os.path.join(self.base_data_directory, "cmvn.scp") + feats = f"ark,s,cs:apply-cmvn --utt2spk=ark:{utt2spk_path} scp:{cmvn_path} scp:{feat_path} ark:- |" + feats += " add-deltas ark:- ark:- |" + return feats + utt2spks = j.construct_path_dictionary(self.data_directory, "utt2spk", "scp") + cmvns = j.construct_path_dictionary(self.data_directory, "cmvn", "scp") + features = j.construct_path_dictionary(self.data_directory, "feats", "scp") + for dict_name in j.current_dictionary_names: + feat_path = features[dict_name] + cmvn_path = cmvns[dict_name] + utt2spk_path = utt2spks[dict_name] + feats = f"ark,s,cs:apply-cmvn --utt2spk=ark:{utt2spk_path} scp:{cmvn_path} scp:{feat_path} ark:- |" + if self.uses_deltas: + feats += " add-deltas ark:- ark:- |" + + return feats + + def construct_feature_proc_strings( + self, + speaker_independent: bool = False, + ) -> list[dict[str, str]]: + """ + Constructs a feature processing string to supply to Kaldi binaries, taking into account corpus features and the + current working directory of the aligner (whether fMLLR or LDA transforms should be used, etc). 
+ + Parameters + ---------- + speaker_independent: bool + Flag for whether features should be speaker-independent regardless of the presence of fMLLR transforms + + Returns + ------- + list[dict[str, str]] + Feature strings per job + """ + strings = [] + for j in self.jobs: + lda_mat_path = None + fmllrs = {} + if self.working_directory is not None: + lda_mat_path = os.path.join(self.working_directory, "lda.mat") + if not os.path.exists(lda_mat_path): + lda_mat_path = None + + fmllrs = j.construct_path_dictionary(self.working_directory, "trans", "ark") + utt2spks = j.construct_path_dictionary(self.data_directory, "utt2spk", "scp") + cmvns = j.construct_path_dictionary(self.data_directory, "cmvn", "scp") + features = j.construct_path_dictionary(self.data_directory, "feats", "scp") + vads = j.construct_path_dictionary(self.data_directory, "vad", "scp") + feat_strings = {} + for dict_name in j.current_dictionary_names: + feat_path = features[dict_name] + cmvn_path = cmvns[dict_name] + utt2spk_path = utt2spks[dict_name] + fmllr_trans_path = None + try: + fmllr_trans_path = fmllrs[dict_name] + if not os.path.exists(fmllr_trans_path): + fmllr_trans_path = None + except KeyError: + pass + vad_path = vads[dict_name] + if self.uses_voiced: + feats = f"ark,s,cs:add-deltas scp:{feat_path} ark:- |" + if self.uses_cmvn: + feats += " apply-cmvn-sliding --norm-vars=false --center=true --cmn-window=300 ark:- ark:- |" + feats += f" select-voiced-frames ark:- scp,s,cs:{vad_path} ark:- |" + elif not os.path.exists(cmvn_path) and self.uses_cmvn: + feats = f"ark,s,cs:add-deltas scp:{feat_path} ark:- |" + if self.uses_cmvn: + feats += " apply-cmvn-sliding --norm-vars=false --center=true --cmn-window=300 ark:- ark:- |" + else: + feats = f"ark,s,cs:apply-cmvn --utt2spk=ark:{utt2spk_path} scp:{cmvn_path} scp:{feat_path} ark:- |" + if lda_mat_path is not None: + if not os.path.exists(lda_mat_path): + raise Exception(f"Could not find {lda_mat_path}") + feats += f" splice-feats 
--left-context={self.splice_left_context} --right-context={self.splice_right_context} ark:- ark:- |" + feats += f" transform-feats {lda_mat_path} ark:- ark:- |" + elif self.uses_splices: + feats += f" splice-feats --left-context={self.splice_left_context} --right-context={self.splice_right_context} ark:- ark:- |" + elif self.uses_deltas: + feats += " add-deltas ark:- ark:- |" + + if fmllr_trans_path is not None and not ( + self.speaker_independent or speaker_independent + ): + if not os.path.exists(fmllr_trans_path): + raise Exception(f"Could not find {fmllr_trans_path}") + feats += f" transform-feats --utt2spk=ark:{utt2spk_path} ark:{fmllr_trans_path} ark:- ark:- |" + feat_strings[dict_name] = feats + strings.append(feat_strings) + return strings + + def compute_vad_arguments(self) -> list[VadArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.corpus.features.compute_vad_func` + + Returns + ------- + list[:class:`~montreal_forced_aligner.corpus.features.VadArguments`] + Arguments for processing + """ + return [ + VadArguments( + os.path.join(self.split_directory, "log", f"compute_vad.{j.name}.log"), + j.current_dictionary_names, + j.construct_path_dictionary(self.split_directory, "feats", "scp"), + j.construct_path_dictionary(self.split_directory, "vad", "scp"), + self.vad_options, + ) + for j in self.jobs + ] + + def calc_fmllr_arguments(self) -> list[CalcFmllrArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.corpus.features.calc_fmllr_func` + + Returns + ------- + list[:class:`~montreal_forced_aligner.corpus.features.CalcFmllrArguments`] + Arguments for processing + """ + feature_strings = self.construct_feature_proc_strings() + return [ + CalcFmllrArguments( + os.path.join(self.working_log_directory, f"calc_fmllr.{j.name}.log"), + j.current_dictionary_names, + feature_strings[j.name], + j.construct_path_dictionary(self.working_directory, "ali", "ark"), + self.alignment_model_path, + self.model_path, + 
j.construct_path_dictionary(self.data_directory, "spk2utt", "scp"), + j.construct_path_dictionary(self.working_directory, "trans", "ark"), + self.fmllr_options, + ) + for j in self.jobs + ] + + def mfcc_arguments(self) -> list[MfccArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.corpus.features.mfcc_func` + + Returns + ------- + list[:class:`~montreal_forced_aligner.corpus.features.MfccArguments`] + Arguments for processing + """ + return [ + MfccArguments( + os.path.join(self.split_directory, "log", f"make_mfcc.{j.name}.log"), + j.current_dictionary_names, + j.construct_path_dictionary(self.split_directory, "feats", "scp"), + j.construct_path_dictionary(self.split_directory, "utterance_lengths", "scp"), + j.construct_path_dictionary(self.split_directory, "segments", "scp"), + j.construct_path_dictionary(self.split_directory, "wav", "scp"), + self.mfcc_options, + ) + for j in self.jobs + ] + + def mfcc(self) -> None: + """ + Multiprocessing function that converts sound files into MFCCs. + + See :kaldi_docs:`feat` for an overview on feature generation in Kaldi. 
+ + See Also + -------- + :func:`~montreal_forced_aligner.corpus.features.mfcc_func` + Multiprocessing helper function for each job + :meth:`.AcousticCorpusMixin.mfcc_arguments` + Job method for generating arguments for helper function + :kaldi_steps:`make_mfcc` + Reference Kaldi script + """ + log_directory = os.path.join(self.split_directory, "log") + os.makedirs(log_directory, exist_ok=True) + + jobs = self.mfcc_arguments() + if self.use_mp: + run_mp(mfcc_func, jobs, log_directory) + else: + run_non_mp(mfcc_func, jobs, log_directory) + + def calc_cmvn(self) -> None: + """ + Calculate CMVN statistics for speakers + + See Also + -------- + :kaldi_src:`compute-cmvn-stats` + Relevant Kaldi binary + """ + spk2utt = os.path.join(self.corpus_output_directory, "spk2utt.scp") + feats = os.path.join(self.corpus_output_directory, "feats.scp") + cmvn_directory = os.path.join(self.features_directory, "cmvn") + os.makedirs(cmvn_directory, exist_ok=True) + cmvn_ark = os.path.join(cmvn_directory, "cmvn.ark") + cmvn_scp = os.path.join(cmvn_directory, "cmvn.scp") + log_path = os.path.join(cmvn_directory, "cmvn.log") + with open(log_path, "w") as logf: + subprocess.call( + [ + thirdparty_binary("compute-cmvn-stats"), + f"--spk2utt=ark:{spk2utt}", + f"scp:{feats}", + f"ark,scp:{cmvn_ark},{cmvn_scp}", + ], + stderr=logf, + env=os.environ, + ) + shutil.copy(cmvn_scp, os.path.join(self.corpus_output_directory, "cmvn.scp")) + for s, cmvn in load_scp(cmvn_scp).items(): + self.speakers[s].cmvn = cmvn + self.create_corpus_split() + + def calc_fmllr(self) -> None: + """ + Multiprocessing function that computes speaker adaptation transforms via + Feature space Maximum Likelihood Linear Regression (fMLLR). 
+ + See Also + -------- + :func:`~montreal_forced_aligner.corpus.features.calc_fmllr_func` + Multiprocessing helper function for each job + :meth:`.AcousticCorpusMixin.calc_fmllr_arguments` + Job method for generating arguments for the helper function + :kaldi_steps:`align_fmllr` + Reference Kaldi script + :kaldi_steps:`train_sat` + Reference Kaldi script + """ + begin = time.time() + log_directory = self.working_log_directory + + jobs = self.calc_fmllr_arguments() + if self.use_mp: + run_mp(calc_fmllr_func, jobs, log_directory) + else: + run_non_mp(calc_fmllr_func, jobs, log_directory) + self.speaker_independent = False + self.log_debug(f"Fmllr calculation took {time.time() - begin}") + + def compute_vad(self) -> None: + """ + Compute Voice Activity Detection features over the corpus + + See Also + -------- + :func:`~montreal_forced_aligner.corpus.features.compute_vad_func` + Multiprocessing helper function for each job + :meth:`.AcousticCorpusMixin.compute_vad_arguments` + Job method for generating arguments for helper function + """ + if os.path.exists(os.path.join(self.split_directory, "vad.0.scp")): + self.log_info("VAD already computed, skipping!") + return + self.log_info("Computing VAD...") + log_directory = os.path.join(self.split_directory, "log") + os.makedirs(log_directory, exist_ok=True) + jobs = self.compute_vad_arguments() + if self.use_mp: + run_mp(compute_vad_func, jobs, log_directory) + else: + run_non_mp(compute_vad_func, jobs, log_directory) + + def combine_feats(self) -> None: + """ + Combine feature generation results and store relevant information + """ + split_directory = self.split_directory + ignore_check = [] + for job in self.jobs: + feats_paths = job.construct_path_dictionary(split_directory, "feats", "scp") + lengths_paths = job.construct_path_dictionary( + split_directory, "utterance_lengths", "scp" + ) + for dict_name in job.current_dictionary_names: + path = feats_paths[dict_name] + lengths_path = lengths_paths[dict_name] + if 
os.path.exists(lengths_path):
+                    with open(lengths_path, "r") as inf:
+                        for line in inf:
+                            line = line.strip()
+                            utt, length = line.split()
+                            length = int(length)
+                            if length < 13:  # Minimum length to align one phone plus silence
+                                self.utterances[utt].ignored = True
+                                ignore_check.append(utt)
+                            self.utterances[utt].feature_length = length
+                with open(path, "r") as inf:
+                    for line in inf:
+                        line = line.strip()
+                        if line == "":
+                            continue
+                        f = line.split(maxsplit=1)
+                        if self.utterances[f[0]].ignored:
+                            continue
+                        self.utterances[f[0]].features = f[1]
+        for u, utterance in self.utterances.items():
+            if utterance.features is None:
+                utterance.ignored = True
+                ignore_check.append(u)
+        if ignore_check:
+            self.log_warning(
+                "There were some utterances ignored due to short duration, see the log file for full "
+                "details or run `mfa validate` on the corpus."
+            )
+            self.log_debug(
+                f"The following utterances were too short to run alignment: "
+                f"{', '.join(ignore_check)}"
+            )
+        self.write_corpus_information()
+
+    def _write_feats(self):
+        """Write feats scp file for Kaldi"""
+        if any(x.features is not None for x in self.utterances.values()):
+            with open(
+                os.path.join(self.corpus_output_directory, "feats.scp"), "w", encoding="utf8"
+            ) as f:
+                for utterance in self.utterances.values():
+                    if not utterance.features:
+                        continue
+                    f.write(f"{utterance.name} {utterance.features}\n")
+
+    def get_feat_dim(self) -> int:
+        """
+        Calculate the feature dimension for the corpus
+
+        Returns
+        -------
+        int
+            Dimension of feature vectors
+        """
+        feature_string = self.construct_base_feature_string()
+        with open(os.path.join(self.features_log_directory, "feat-to-dim.log"), "w") as log_file:
+            subset_proc = subprocess.Popen(
+                [
+                    thirdparty_binary("subset-feats"),
+                    "--n=1",
+                    feature_string,
+                    "ark:-",
+                ],
+                stderr=log_file,
+                stdout=subprocess.PIPE,
+            )
+            dim_proc = subprocess.Popen(
+                [thirdparty_binary("feat-to-dim"), "ark:-", "-"],
+                stdin=subset_proc.stdout,
+
stdout=subprocess.PIPE, + stderr=log_file, + ) + stdout, stderr = dim_proc.communicate() + feats = stdout.decode("utf8").strip() + return int(feats) + + def _load_corpus_from_source_mp(self) -> None: + """ + Load a corpus using multiprocessing + """ + begin_time = time.time() + manager = mp.Manager() + job_queue = manager.Queue() + return_queue = manager.Queue() + return_dict = manager.dict() + return_dict["sound_file_errors"] = manager.list() + return_dict["decode_error_files"] = manager.list() + return_dict["textgrid_read_errors"] = manager.dict() + finished_adding = Stopped() + procs = [] + for _ in range(self.num_jobs): + p = CorpusProcessWorker( + job_queue, return_dict, return_queue, self.stopped, finished_adding + ) + procs.append(p) + p.start() + try: + + use_audio_directory = False + all_sound_files = {} + if self.audio_directory and os.path.exists(self.audio_directory): + use_audio_directory = True + for root, _, files in os.walk(self.audio_directory, followlinks=True): + ( + identifiers, + wav_files, + lab_files, + textgrid_files, + other_audio_files, + ) = find_exts(files) + wav_files = {k: os.path.join(root, v) for k, v in wav_files.items()} + other_audio_files = { + k: os.path.join(root, v) for k, v in other_audio_files.items() + } + all_sound_files.update(other_audio_files) + all_sound_files.update(wav_files) + + for root, _, files in os.walk(self.corpus_directory, followlinks=True): + identifiers, wav_files, lab_files, textgrid_files, other_audio_files = find_exts( + files + ) + relative_path = root.replace(self.corpus_directory, "").lstrip("/").lstrip("\\") + + if self.stopped.stop_check(): + break + if not use_audio_directory: + all_sound_files = {} + wav_files = {k: os.path.join(root, v) for k, v in wav_files.items()} + other_audio_files = { + k: os.path.join(root, v) for k, v in other_audio_files.items() + } + all_sound_files.update(other_audio_files) + all_sound_files.update(wav_files) + for file_name in identifiers: + if 
self.stopped.stop_check(): + break + wav_path = None + transcription_path = None + if file_name in all_sound_files: + wav_path = all_sound_files[file_name] + if file_name in lab_files: + lab_name = lab_files[file_name] + transcription_path = os.path.join(root, lab_name) + + elif file_name in textgrid_files: + tg_name = textgrid_files[file_name] + transcription_path = os.path.join(root, tg_name) + if wav_path is None: + self.transcriptions_without_wavs.append(transcription_path) + continue + if transcription_path is None: + self.no_transcription_files.append(wav_path) + if hasattr(self, "construct_sanitize_function"): + job_queue.put( + ( + file_name, + wav_path, + transcription_path, + relative_path, + self.speaker_characters, + self.construct_sanitize_function(), + ) + ) + else: + job_queue.put( + ( + file_name, + wav_path, + transcription_path, + relative_path, + self.speaker_characters, + None, + ) + ) + + finished_adding.stop() + self.log_debug("Finished adding jobs!") + job_queue.join() + + self.log_debug("Waiting for workers to finish...") + for p in procs: + p.join() + + while True: + try: + file = return_queue.get(timeout=1) + if self.stopped.stop_check(): + continue + except Empty: + break + + self.add_file(file) + + if "error" in return_dict: + raise return_dict["error"][1] + + for k in ["sound_file_errors", "decode_error_files", "textgrid_read_errors"]: + if hasattr(self, k): + if return_dict[k]: + self.log_info( + "There were some issues with files in the corpus. " + "Please look at the log file or run the validator for more information." 
+ ) + self.log_debug(f"{k} showed {len(return_dict[k])} errors:") + if k == "textgrid_read_errors": + getattr(self, k).update(return_dict[k]) + for f, e in return_dict[k].items(): + self.log_debug(f"{f}: {e.error}") + else: + self.log_debug(", ".join(return_dict[k])) + setattr(self, k, return_dict[k]) + + except KeyboardInterrupt: + self.log_info("Detected ctrl-c, please wait a moment while we clean everything up...") + self.stopped.stop() + finished_adding.stop() + job_queue.join() + self.stopped.set_sigint_source() + while True: + try: + _ = return_queue.get(timeout=1) + if self.stopped.stop_check(): + continue + except Empty: + break + finally: + + if self.stopped.stop_check(): + self.log_info(f"Stopped parsing early ({time.time() - begin_time} seconds)") + if self.stopped.source(): + sys.exit(0) + else: + self.log_debug( + f"Parsed corpus directory with {self.num_jobs} jobs in {time.time() - begin_time} seconds" + ) + + def _load_corpus_from_source(self) -> None: + """ + Load a corpus without using multiprocessing + """ + begin_time = time.time() + + all_sound_files = {} + use_audio_directory = False + if self.audio_directory and os.path.exists(self.audio_directory): + use_audio_directory = True + for root, _, files in os.walk(self.audio_directory, followlinks=True): + if self.stopped.stop_check(): + return + identifiers, wav_files, lab_files, textgrid_files, other_audio_files = find_exts( + files + ) + wav_files = {k: os.path.join(root, v) for k, v in wav_files.items()} + other_audio_files = { + k: os.path.join(root, v) for k, v in other_audio_files.items() + } + all_sound_files.update(other_audio_files) + all_sound_files.update(wav_files) + + for root, _, files in os.walk(self.corpus_directory, followlinks=True): + identifiers, wav_files, lab_files, textgrid_files, other_audio_files = find_exts(files) + relative_path = root.replace(self.corpus_directory, "").lstrip("/").lstrip("\\") + if self.stopped.stop_check(): + return + if not use_audio_directory: + 
all_sound_files = {} + wav_files = {k: os.path.join(root, v) for k, v in wav_files.items()} + other_audio_files = { + k: os.path.join(root, v) for k, v in other_audio_files.items() + } + all_sound_files.update(other_audio_files) + all_sound_files.update(wav_files) + for file_name in identifiers: + + wav_path = None + transcription_path = None + if file_name in all_sound_files: + wav_path = all_sound_files[file_name] + if file_name in lab_files: + lab_name = lab_files[file_name] + transcription_path = os.path.join(root, lab_name) + elif file_name in textgrid_files: + tg_name = textgrid_files[file_name] + transcription_path = os.path.join(root, tg_name) + if wav_path is None: + self.transcriptions_without_wavs.append(transcription_path) + continue + if transcription_path is None: + self.no_transcription_files.append(wav_path) + + try: + if hasattr(self, "construct_sanitize_function"): + file = parse_file( + file_name, + wav_path, + transcription_path, + relative_path, + self.speaker_characters, + self.construct_sanitize_function(), + ) + else: + file = parse_file( + file_name, + wav_path, + transcription_path, + relative_path, + self.speaker_characters, + None, + ) + self.add_file(file) + except TextParseError as e: + self.decode_error_files.append(e) + except TextGridParseError as e: + self.textgrid_read_errors[e.file_name] = e + if self.decode_error_files or self.textgrid_read_errors: + self.log_info( + "There were some issues with files in the corpus. " + "Please look at the log file or run the validator for more information." 
+            )
+        if self.decode_error_files:
+            self.log_debug(
+                f"There were {len(self.decode_error_files)} errors decoding text files:"
+            )
+            self.log_debug(", ".join(self.decode_error_files))
+        if self.textgrid_read_errors:
+            self.log_debug(
+                f"There were {len(self.textgrid_read_errors)} errors reading TextGrid files:"
+            )
+            for f, e in self.textgrid_read_errors.items():
+                self.log_debug(f"{f}: {e.error}")
+
+        self.log_debug(f"Parsed corpus directory in {time.time() - begin_time} seconds")
+
+
+class AcousticCorpusPronunciationMixin(
+    AcousticCorpusMixin, MultispeakerDictionaryMixin, metaclass=ABCMeta
+):
+    """
+    Mixin for acoustic corpora with Pronunciation dictionaries
+
+    See Also
+    --------
+    :class:`~montreal_forced_aligner.corpus.acoustic_corpus.AcousticCorpusMixin`
+        For corpus parsing parameters
+    :class:`~montreal_forced_aligner.dictionary.multispeaker.MultispeakerDictionaryMixin`
+        For dictionary parsing parameters
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def load_corpus(self) -> None:
+        """
+        Load the corpus
+        """
+        all_begin = time.time()
+        self.dictionary_setup()
+        self.log_debug(f"Loaded dictionary in {time.time() - all_begin}")
+
+        begin = time.time()
+        self._load_corpus()
+        self.log_debug(f"Loaded corpus in {time.time() - begin}")
+
+        begin = time.time()
+        self.set_lexicon_word_set(self.corpus_word_set)
+        self.log_debug(f"Set up lexicon word set in {time.time() - begin}")
+
+        begin = time.time()
+        self.write_lexicon_information()
+        self.log_debug(f"Wrote lexicon information in {time.time() - begin}")
+
+        begin = time.time()
+        for speaker in self.speakers.values():
+            speaker.set_dictionary(self.get_dictionary(speaker.name))
+        self.log_debug(f"Set dictionaries for speakers in {time.time() - begin}")
+
+        begin = time.time()
+        self.initialize_jobs()
+        self.log_debug(f"Initialized jobs in {time.time() - begin}")
+
+        begin = time.time()
+        self.write_corpus_information()
+        self.log_debug(f"Wrote corpus information in 
{time.time() - begin}") + + begin = time.time() + self.create_corpus_split() + self.log_debug(f"Created corpus split directory in {time.time() - begin}") + + begin = time.time() + self.generate_features() + self.log_debug(f"Generated features in {time.time() - begin}") + + begin = time.time() + self.calculate_oovs_found() + self.log_debug(f"Calculated oovs found in {time.time() - begin}") + self.log_debug(f"Setting up corpus took {time.time() - all_begin} seconds") + + +class AcousticCorpus(AcousticCorpusPronunciationMixin, MfaWorker, TemporaryDirectoryMixin): + """ + Standalone class for working with acoustic corpora and pronunciation dictionaries + + Most functionality in MFA will use the :class:`~montreal_forced_aligner.corpus.acoustic_corpus.AcousticCorpusPronunciationMixin` class instead of this class. + + Parameters + ---------- + num_jobs: int + Number of jobs to use in processing the corpus + + See Also + -------- + :class:`~montreal_forced_aligner.corpus.acoustic_corpus.AcousticCorpusPronunciationMixin` + For dictionary and corpus parsing parameters + :class:`~montreal_forced_aligner.abc.MfaWorker` + For MFA processing parameters + :class:`~montreal_forced_aligner.abc.TemporaryDirectoryMixin` + For temporary directory parameters + """ + + def __init__(self, num_jobs=3, **kwargs): + super(AcousticCorpus, self).__init__(**kwargs) + self.num_jobs = num_jobs + + @property + def identifier(self) -> str: + """Identifier for the corpus""" + return self.data_source_identifier + + @property + def output_directory(self) -> str: + """Root temporary directory to store corpus and dictionary files""" + return os.path.join(self.temporary_directory, self.identifier) + + @property + def working_directory(self) -> str: + """Working directory to save temporary corpus and dictionary files""" + return self.output_directory + + def log_debug(self, message: str) -> None: + """ + Print a debug message + + Parameters + ---------- + message: str + Debug message to log + """ + 
print(message) + + def log_error(self, message: str) -> None: + """ + Print an error message + + Parameters + ---------- + message: str + Error message to log + """ + print(message) + + def log_info(self, message: str) -> None: + """ + Print an info message + + Parameters + ---------- + message: str + Info message to log + """ + print(message) + + def log_warning(self, message: str) -> None: + """ + Print a warning message + + Parameters + ---------- + message: str + Warning message to log + """ + print(message) diff --git a/montreal_forced_aligner/corpus/base.py b/montreal_forced_aligner/corpus/base.py index cebac349..59fe09c8 100644 --- a/montreal_forced_aligner/corpus/base.py +++ b/montreal_forced_aligner/corpus/base.py @@ -1,631 +1,183 @@ """Class definitions for corpora""" from __future__ import annotations -import logging -import multiprocessing as mp import os import random -import subprocess -import sys import time +from abc import ABCMeta, abstractmethod from collections import Counter -from queue import Empty -from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Union +from typing import Optional, Union import yaml -from ..config import FeatureConfig -from ..config.dictionary_config import DictionaryConfig -from ..exceptions import CorpusError, KaldiProcessingError, TextGridParseError, TextParseError -from ..helper import output_mapping -from ..multiprocessing import Job -from ..multiprocessing.corpus import CorpusProcessWorker -from ..multiprocessing.features import calc_cmvn, compute_vad, mfcc -from ..multiprocessing.helper import Stopped -from ..utils import log_kaldi_errors, thirdparty_binary -from .classes import File, Speaker, Utterance, parse_file -from .helper import find_exts +from montreal_forced_aligner.abc import MfaWorker, TemporaryDirectoryMixin +from montreal_forced_aligner.corpus.classes import File, Speaker, Utterance +from montreal_forced_aligner.corpus.multiprocessing import Job +from montreal_forced_aligner.exceptions 
import CorpusError +from montreal_forced_aligner.helper import output_mapping +from montreal_forced_aligner.utils import Stopped -if TYPE_CHECKING: - from logging import Logger +__all__ = ["CorpusMixin"] - from ..dictionary import MultispeakerDictionary - -__all__ = ["Corpus"] - - -class Corpus: +class CorpusMixin(MfaWorker, TemporaryDirectoryMixin, metaclass=ABCMeta): """ - Class that stores information about the dataset to align. - - Corpus objects have a number of mappings from either utterances or speakers - to various properties, and mappings between utterances and speakers. - - See http://kaldi-asr.org/doc/data_prep.html for more information about - the files that are created by this class. + Mixin class for processing corpora + Notes + ----- + Using characters in files to specify speakers is generally finicky and leads to errors, so I would not + recommend using it. Additionally, consider it deprecated and could be removed in future versions Parameters ---------- - directory : str - Directory of the dataset to align - output_directory : str - Directory to store generated data for the Kaldi binaries - speaker_characters : int, optional - Number of characters in the filenames to count as the speaker ID, - if not specified, speaker IDs are generated from directory names - num_jobs : int, optional - Number of processes to use, defaults to 3 - sample_rate : int, optional - Default sample rate to use for feature generation, defaults to 16000 - debug : bool - Flag to enable debug mode, defaults to False - logger : :class:`~logging.Logger` - Logger to use - use_mp : bool - Flag to enable multiprocessing, defaults to True - punctuation : str, optional - Characters to treat as punctuation in parsing text - clitic_markers : str, optional - Characters to treat as clitic markers in parsing text - audio_directory : str, optional - Additional directory to parse for audio files - skip_load : bool - Flag to skip loading when initializing, defaults to False - 
parse_text_only_files : bool
-        Flag to parse text files that do not have associated sound files, defaults to False
+    corpus_directory: str
+        Path to corpus
+    speaker_characters: int or str, optional
+        Number of characters in the file name to specify the speaker
+    ignore_speakers: bool
+        Flag for whether to discard any parsed speaker information during top-level worker's processing
+
+    See Also
+    --------
+    :class:`~montreal_forced_aligner.abc.MfaWorker`
+        For MFA processing parameters
+    :class:`~montreal_forced_aligner.abc.TemporaryDirectoryMixin`
+        For temporary directory parameters
+
+    Attributes
+    ----------
+    speakers: dict[str, Speaker]
+        Dictionary of speakers in the corpus
+    files: dict[str, File]
+        Dictionary of files in the corpus
+    utterances: dict[str, Utterance]
+        Dictionary of utterances in the corpus
+    jobs: list[Job]
+        List of jobs for processing the corpus and splitting speakers
+    word_counts: Counter
+        Counts of words in the corpus
+    stopped: Stopped
+        Stop check for loading the corpus
+    decode_error_files: list[str]
+        List of text files that could not be loaded with utf8
+    textgrid_read_errors: dict[str, TextGridParseError]
+        Dictionary of TextGrid files that had an error in loading, mapped to the parse error raised
+    """

     def __init__(
         self,
-        directory: str,
-        output_directory: str,
-        dictionary_config: Optional[DictionaryConfig] = None,
+        corpus_directory: str,
         speaker_characters: Union[int, str] = 0,
-        num_jobs: int = 3,
-        sample_rate: int = 16000,
-        debug: bool = False,
-        logger: Optional[Logger] = None,
-        use_mp: bool = True,
-        audio_directory: Optional[str] = None,
-        skip_load: bool = False,
-        parse_text_only_files: bool = False,
         ignore_speakers: bool = False,
+        **kwargs,
     ):
-        self.audio_directory = audio_directory
-        self.dictionary_config = dictionary_config
-        self.debug = debug
-        self.use_mp = use_mp
-        log_dir = os.path.join(output_directory, "logging")
-        os.makedirs(log_dir, exist_ok=True)
-        self.name = os.path.basename(directory)
-        self.log_file = os.path.join(log_dir, "corpus.log")
-        if logger is 
None: - self.logger = logging.getLogger("corpus_setup") - self.logger.setLevel(logging.INFO) - handler = logging.FileHandler(self.log_file, "w", "utf-8") - handler.setFormatter = logging.Formatter("%(name)s %(message)s") - self.logger.addHandler(handler) - else: - self.logger = logger - if not os.path.exists(directory): - raise CorpusError(f"The directory '{directory}' does not exist.") - if not os.path.isdir(directory): + if not os.path.exists(corpus_directory): + raise CorpusError(f"The directory '{corpus_directory}' does not exist.") + if not os.path.isdir(corpus_directory): raise CorpusError( - f"The specified path for the corpus ({directory}) is not a directory." + f"The specified path for the corpus ({corpus_directory}) is not a directory." ) - - num_jobs = max(num_jobs, 1) - if num_jobs == 1: - self.use_mp = False - self.original_num_jobs = num_jobs - self.logger.info("Setting up corpus information...") - self.directory = directory - self.output_directory = os.path.join(output_directory, "corpus_data") - self.temp_directory = os.path.join(self.output_directory, "temp") - os.makedirs(self.temp_directory, exist_ok=True) + self.speakers: dict[str, Speaker] = {} + self.files: dict[str, File] = {} + self.utterances: dict[str, Utterance] = {} + self.corpus_directory = corpus_directory self.speaker_characters = speaker_characters - if speaker_characters == 0: - self.speaker_directories = True - else: - self.speaker_directories = False - self.num_jobs = num_jobs - self.speakers: Dict[str, Speaker] = {} - self.files: Dict[str, File] = {} - self.utterances: Dict[str, Utterance] = {} - self.sound_file_errors = [] + self.ignore_speakers = ignore_speakers + self.word_counts = Counter() + self.stopped = Stopped() self.decode_error_files = [] - self.transcriptions_without_wavs = [] - self.no_transcription_files = [] self.textgrid_read_errors = {} - self.groups = [] - self.speaker_groups = [] - self.word_counts = Counter() - self.sample_rate = sample_rate - if self.use_mp: 
- self.stopped = Stopped() - else: - self.stopped = False - - self.skip_load = skip_load - self.utterances_time_sorted = False - self.parse_text_only_files = parse_text_only_files - self.feature_config = FeatureConfig() - self.vad_config = {"energy_threshold": 5.5, "energy_mean_scale": 0.5} - self.no_speakers = ignore_speakers - self.vad_segments = {} - if self.use_mp: - self.stopped = Stopped() - else: - self.stopped = False - if not self.skip_load: - self.load() + self.jobs: list[Job] = [] + super().__init__(**kwargs) - def normalized_text_iter(self, min_count: int = 1) -> Generator: - """ - Construct an iterator over the normalized texts in the corpus + @property + def features_directory(self) -> str: + """Feature directory of the corpus""" + return os.path.join(self.corpus_output_directory, "features") - Parameters - ---------- - min_count: int - Minimum word count to include in the output, otherwise will use OOV code, defaults to 1 + @property + def features_log_directory(self) -> str: + """Feature log directory""" + return os.path.join(self.split_directory, "log") - Yields - ------- - str - Normalized text - """ - unk_words = {k for k, v in self.word_counts.items() if v <= min_count} - for u in self.utterances.values(): - text = u.text.split() - new_text = [] - for t in text: - if u.speaker.dictionary is not None: - u.speaker.dictionary.to_int(t) - lookup = u.speaker.dictionary.split_clitics(t) - if lookup is None: - continue - else: - lookup = [t] - for item in lookup: - if item in unk_words: - new_text.append("") - elif ( - u.speaker.dictionary is not None and item not in u.speaker.dictionary.words - ): - new_text.append("") - else: - new_text.append(item) - yield " ".join(new_text) + @property + def split_directory(self) -> str: + """Directory used to store information split by job""" + return os.path.join(self.corpus_output_directory, f"split{self.num_jobs}") - def subset_directory(self, subset: Optional[int]) -> str: + def write_corpus_information(self) 
-> None: """ - Construct a subset directory for the corpus - - Parameters - ---------- - subset: int, optional - Number of utterances to include, if larger than the total number of utterance or not specified, the - split_directory is returned - - Returns - ------- - str - Path to subset directory + Output information to the temporary directory for later loading """ - if subset is None or subset > self.num_utterances or subset <= 0: - for j in self.jobs: - j.set_subset(None) - return self.split_directory - directory = os.path.join(self.output_directory, f"subset_{subset}") - self.create_subset(subset) - return directory + os.makedirs(self.split_directory, exist_ok=True) + self._write_speakers() + self._write_files() + self._write_utterances() + self._write_spk2utt() - def initialize_utt_fsts(self) -> None: - """ - Construct utterance FSTs - """ - for j in self.jobs: - j.output_utt_fsts(self) + def _write_spk2utt(self): + """Write spk2utt scp file for Kaldi""" + data = { + speaker.name: sorted(speaker.utterances.keys()) for speaker in self.speakers.values() + } + output_mapping(data, os.path.join(self.corpus_output_directory, "spk2utt.scp")) - def create_subset(self, subset: Optional[int]) -> None: - """ - Create a subset of utterances to use for training + def write_utt2spk(self): + """Write utt2spk scp file for Kaldi""" + data = {u.name: u.speaker.name for u in self.utterances.values()} + output_mapping(data, os.path.join(self.corpus_output_directory, "utt2spk.scp")) - Parameters - ---------- - subset: int - Number of utterances to include in subset - """ - subset_directory = os.path.join(self.output_directory, f"subset_{subset}") + def _write_speakers(self): + """Write speaker information for speeding up future runs""" + to_save = [] + for speaker in self.speakers.values(): + to_save.append(speaker.meta) + with open( + os.path.join(self.corpus_output_directory, "speakers.yaml"), "w", encoding="utf8" + ) as f: + yaml.safe_dump(to_save, f) - larger_subset_num = 
subset * 10 - if larger_subset_num < self.num_utterances: - # Get all shorter utterances that are not one word long - utts = sorted( - (utt for utt in self.utterances.values() if " " in utt.text), - key=lambda x: x.duration, - ) - larger_subset = utts[:larger_subset_num] - else: - larger_subset = sorted(self.utterances.values()) - random.seed(1234) # make it deterministic sampling - subset_utts = set(random.sample(larger_subset, subset)) - log_dir = os.path.join(subset_directory, "log") - os.makedirs(log_dir, exist_ok=True) + def _write_files(self): + """Write file information for speeding up future runs""" + to_save = [] + for file in self.files.values(): + to_save.append(file.meta) + with open( + os.path.join(self.corpus_output_directory, "files.yaml"), "w", encoding="utf8" + ) as f: + yaml.safe_dump(to_save, f) - for j in self.jobs: - j.set_subset(subset_utts) - j.output_to_directory(subset_directory) + def _write_utterances(self): + """Write utterance information for speeding up future runs""" + to_save = [] + for utterance in self.utterances.values(): + to_save.append(utterance.meta) + with open( + os.path.join(self.corpus_output_directory, "utterances.yaml"), "w", encoding="utf8" + ) as f: + yaml.safe_dump(to_save, f) - def load(self) -> None: - """ - Load the corpus - """ - loaded = self._load_from_temp() - if not loaded: - if self.use_mp: - self.logger.debug("Loading from source with multiprocessing") - self._load_from_source_mp() - else: - self.logger.debug("Loading from source without multiprocessing") - self._load_from_source() - else: - self.logger.debug("Successfully loaded from temporary files") + def create_corpus_split(self) -> None: + """Create split directory and output information from Jobs""" + split_dir = self.split_directory + os.makedirs(os.path.join(split_dir, "log"), exist_ok=True) + for job in self.jobs: + job.output_to_directory(split_dir) @property - def file_speaker_mapping(self) -> Dict[str, List[str]]: + def file_speaker_mapping(self) 
-> dict[str, list[str]]: """Speaker ordering for each file""" return {file_name: file.speaker_ordering for file_name, file in self.files.items()} - def _load_from_temp(self) -> bool: + def get_word_frequency(self) -> dict[str, float]: """ - Load a corpus from saved data in the temporary directory + Calculate the relative word frequency across all the texts in the corpus Returns ------- - bool - Whether loading from temporary files was successful - """ - begin_time = time.time() - for f in os.listdir(self.output_directory): - if f.startswith("split"): - old_num_jobs = int(f.replace("split", "")) - if old_num_jobs != self.num_jobs: - self.logger.info( - f"Found old run with {old_num_jobs} rather than the current {self.num_jobs}, " - f"setting to {old_num_jobs}. If you would like to use {self.num_jobs}, re-run the command " - f"with --clean." - ) - self.num_jobs = old_num_jobs - speakers_path = os.path.join(self.output_directory, "speakers.yaml") - files_path = os.path.join(self.output_directory, "files.yaml") - utterances_path = os.path.join(self.output_directory, "utterances.yaml") - - if not os.path.exists(speakers_path): - self.logger.debug(f"Could not find {speakers_path}, cannot load from temp") - return False - if not os.path.exists(files_path): - self.logger.debug(f"Could not find {files_path}, cannot load from temp") - return False - if not os.path.exists(utterances_path): - self.logger.debug(f"Could not find {utterances_path}, cannot load from temp") - return False - self.logger.debug("Loading from temporary files...") - - with open(speakers_path, "r", encoding="utf8") as f: - speaker_data = yaml.safe_load(f) - - for entry in speaker_data: - self.speakers[entry["name"]] = Speaker(entry["name"]) - self.speakers[entry["name"]].cmvn = entry["cmvn"] - - with open(files_path, "r", encoding="utf8") as f: - files_data = yaml.safe_load(f) - for entry in files_data: - self.files[entry["name"]] = File( - entry["wav_path"], entry["text_path"], entry["relative_path"] - 
) - self.files[entry["name"]].speaker_ordering = [ - self.speakers[x] for x in entry["speaker_ordering"] - ] - self.files[entry["name"]].wav_info = entry["wav_info"] - - with open(utterances_path, "r", encoding="utf8") as f: - utterances_data = yaml.safe_load(f) - for entry in utterances_data: - s = self.speakers[entry["speaker"]] - f = self.files[entry["file"]] - u = Utterance( - s, - f, - begin=entry["begin"], - end=entry["end"], - channel=entry["channel"], - text=entry["text"], - ) - self.utterances[u.name] = u - if u.text: - self.word_counts.update(u.text.split()) - self.utterances[u.name].features = entry["features"] - self.utterances[u.name].ignored = entry["ignored"] - - self.logger.debug( - f"Loaded from corpus_data temp directory in {time.time()-begin_time} seconds" - ) - return True - - def _load_from_source_mp(self) -> None: - """ - Load a corpus using multiprocessing - """ - if self.stopped is None: - self.stopped = Stopped() - begin_time = time.time() - manager = mp.Manager() - job_queue = manager.Queue() - return_queue = manager.Queue() - return_dict = manager.dict() - return_dict["sound_file_errors"] = manager.list() - return_dict["decode_error_files"] = manager.list() - return_dict["textgrid_read_errors"] = manager.dict() - finished_adding = Stopped() - procs = [] - for _ in range(self.num_jobs): - p = CorpusProcessWorker( - job_queue, return_dict, return_queue, self.stopped, finished_adding - ) - procs.append(p) - p.start() - try: - - use_audio_directory = False - all_sound_files = {} - if self.audio_directory and os.path.exists(self.audio_directory): - use_audio_directory = True - for root, _, files in os.walk(self.audio_directory, followlinks=True): - ( - identifiers, - wav_files, - lab_files, - textgrid_files, - other_audio_files, - ) = find_exts(files) - wav_files = {k: os.path.join(root, v) for k, v in wav_files.items()} - other_audio_files = { - k: os.path.join(root, v) for k, v in other_audio_files.items() - } - 
all_sound_files.update(other_audio_files) - all_sound_files.update(wav_files) - - for root, _, files in os.walk(self.directory, followlinks=True): - identifiers, wav_files, lab_files, textgrid_files, other_audio_files = find_exts( - files - ) - relative_path = root.replace(self.directory, "").lstrip("/").lstrip("\\") - - if self.stopped.stop_check(): - break - if not use_audio_directory: - all_sound_files = {} - wav_files = {k: os.path.join(root, v) for k, v in wav_files.items()} - other_audio_files = { - k: os.path.join(root, v) for k, v in other_audio_files.items() - } - all_sound_files.update(other_audio_files) - all_sound_files.update(wav_files) - for file_name in identifiers: - if self.stopped.stop_check(): - break - wav_path = None - transcription_path = None - if file_name in all_sound_files: - wav_path = all_sound_files[file_name] - if file_name in lab_files: - lab_name = lab_files[file_name] - transcription_path = os.path.join(root, lab_name) - - elif file_name in textgrid_files: - tg_name = textgrid_files[file_name] - transcription_path = os.path.join(root, tg_name) - if wav_path is None and not self.parse_text_only_files: - self.transcriptions_without_wavs.append(transcription_path) - continue - if transcription_path is None: - self.no_transcription_files.append(wav_path) - job_queue.put( - ( - file_name, - wav_path, - transcription_path, - relative_path, - self.speaker_characters, - self.sample_rate, - self.dictionary_config, - ) - ) - - finished_adding.stop() - self.logger.debug("Finished adding jobs!") - job_queue.join() - - self.logger.debug("Waiting for workers to finish...") - for p in procs: - p.join() - - while True: - try: - file = return_queue.get(timeout=1) - if self.stopped.stop_check(): - continue - except Empty: - break - - self.add_file(file) - - if "error" in return_dict: - raise return_dict["error"][1] - - for k in ["sound_file_errors", "decode_error_files", "textgrid_read_errors"]: - if hasattr(self, k): - if return_dict[k]: - 
self.logger.info( - "There were some issues with files in the corpus. " - "Please look at the log file or run the validator for more information." - ) - self.logger.debug(f"{k} showed {len(return_dict[k])} errors:") - if k == "textgrid_read_errors": - getattr(self, k).update(return_dict[k]) - for f, e in return_dict[k].items(): - self.logger.debug(f"{f}: {e.error}") - else: - self.logger.debug(", ".join(return_dict[k])) - setattr(self, k, return_dict[k]) - - except KeyboardInterrupt: - self.logger.info( - "Detected ctrl-c, please wait a moment while we clean everything up..." - ) - self.stopped.stop() - finished_adding.stop() - job_queue.join() - self.stopped.set_sigint_source() - while True: - try: - _ = return_queue.get(timeout=1) - if self.stopped.stop_check(): - continue - except Empty: - break - finally: - - if self.stopped.stop_check(): - self.logger.info(f"Stopped parsing early ({time.time() - begin_time} seconds)") - if self.stopped.source(): - sys.exit(0) - else: - self.logger.debug( - f"Parsed corpus directory with {self.num_jobs} jobs in {time.time() - begin_time} seconds" - ) - - def _load_from_source(self) -> None: - """ - Load a corpus without using multiprocessing - """ - begin_time = time.time() - self.stopped = False - - all_sound_files = {} - use_audio_directory = False - if self.audio_directory and os.path.exists(self.audio_directory): - use_audio_directory = True - for root, _, files in os.walk(self.audio_directory, followlinks=True): - if self.stopped: - return - identifiers, wav_files, lab_files, textgrid_files, other_audio_files = find_exts( - files - ) - wav_files = {k: os.path.join(root, v) for k, v in wav_files.items()} - other_audio_files = { - k: os.path.join(root, v) for k, v in other_audio_files.items() - } - all_sound_files.update(other_audio_files) - all_sound_files.update(wav_files) - - for root, _, files in os.walk(self.directory, followlinks=True): - identifiers, wav_files, lab_files, textgrid_files, other_audio_files = 
find_exts(files) - relative_path = root.replace(self.directory, "").lstrip("/").lstrip("\\") - if self.stopped: - return - if not use_audio_directory: - all_sound_files = {} - wav_files = {k: os.path.join(root, v) for k, v in wav_files.items()} - other_audio_files = { - k: os.path.join(root, v) for k, v in other_audio_files.items() - } - all_sound_files.update(other_audio_files) - all_sound_files.update(wav_files) - for file_name in identifiers: - - wav_path = None - transcription_path = None - if file_name in all_sound_files: - wav_path = all_sound_files[file_name] - if file_name in lab_files: - lab_name = lab_files[file_name] - transcription_path = os.path.join(root, lab_name) - elif file_name in textgrid_files: - tg_name = textgrid_files[file_name] - transcription_path = os.path.join(root, tg_name) - if wav_path is None and not self.parse_text_only_files: - self.transcriptions_without_wavs.append(transcription_path) - continue - if transcription_path is None: - self.no_transcription_files.append(wav_path) - - try: - file = parse_file( - file_name, - wav_path, - transcription_path, - relative_path, - self.speaker_characters, - self.sample_rate, - self.dictionary_config, - ) - self.add_file(file) - except TextParseError as e: - self.decode_error_files.append(e) - except TextGridParseError as e: - self.textgrid_read_errors[e.file_name] = e - if self.decode_error_files or self.textgrid_read_errors: - self.logger.info( - "There were some issues with files in the corpus. " - "Please look at the log file or run the validator for more information." 
- ) - if self.decode_error_files: - self.logger.debug( - f"There were {len(self.decode_error_files)} errors decoding text files:" - ) - self.logger.debug(", ".join(self.decode_error_files)) - if self.textgrid_read_errors: - self.logger.debug( - f"There were {len(self.textgrid_read_errors)} errors decoding reading TextGrid files:" - ) - for f, e in self.textgrid_read_errors.items(): - self.logger.debug(f"{f}: {e.error}") - - self.logger.debug(f"Parsed corpus directory in {time.time()-begin_time} seconds") - - def add_file(self, file: File) -> None: - """ - Add a file to the corpus - - Parameters - ---------- - file: :class:`~montreal_forced_aligner.corpus.File` - File to be added - """ - self.files[file.name] = file - for speaker in file.speaker_ordering: - if speaker.name not in self.speakers: - self.speakers[speaker.name] = speaker - else: - self.speakers[speaker.name].merge(speaker) - for u in file.utterances.values(): - self.utterances[u.name] = u - if u.text: - self.word_counts.update(u.text.split()) - - def get_word_frequency(self) -> Dict[str, float]: - """ - Calculate the word frequency across all the texts in the corpus - - Returns - ------- - Dict[str, float] - PronunciationDictionary of words and their relative frequencies + dict[str, float] + Dictionary of words and their relative frequencies """ word_counts = Counter() for u in self.utterances.values(): @@ -644,7 +196,7 @@ def get_word_frequency(self) -> Dict[str, float]: return {k: v / sum(word_counts.values()) for k, v in word_counts.items()} @property - def word_set(self) -> List[str]: + def corpus_word_set(self) -> list[str]: """Set of words used in the corpus""" return sorted(self.word_counts) @@ -654,7 +206,7 @@ def add_utterance(self, utterance: Utterance) -> None: Parameters ---------- - utterance: :class:`~montreal_forced_aligner.corpus.Utterance` + utterance: :class:`~montreal_forced_aligner.corpus.classes.Utterance` Utterance to add """ self.utterances[utterance.name] = utterance @@ -669,7 
+221,7 @@ def delete_utterance(self, utterance: Union[str, Utterance]) -> None: Parameters ---------- - utterance: :class:`~montreal_forced_aligner.corpus.Utterance` + utterance: :class:`~montreal_forced_aligner.corpus.classes.Utterance` Utterance to delete """ if isinstance(utterance, str): @@ -682,80 +234,119 @@ def initialize_jobs(self) -> None: """ Initialize the corpus's Jobs """ + self.log_info("Setting up training data...") if len(self.speakers) < self.num_jobs: self.num_jobs = len(self.speakers) self.jobs = [Job(i) for i in range(self.num_jobs)] job_ind = 0 - for s in self.speakers.values(): + for s in sorted(self.speakers.values()): self.jobs[job_ind].add_speaker(s) job_ind += 1 if job_ind == self.num_jobs: job_ind = 0 - def initialize_corpus( - self, - dictionary: Optional[MultispeakerDictionary] = None, - feature_config: Optional[FeatureConfig] = None, - ) -> None: + def add_file(self, file: File) -> None: """ - Initialize corpus for use + Add a file to the corpus Parameters ---------- - dictionary: :class:`~montreal_forced_aligner.dictionary.MultispeakerDictionary`, optional - PronunciationDictionary to use - feature_config: :class:`~montreal_forced_aligner.config.FeatureConfig`, optional - Feature configuration to use + file: :class:`~montreal_forced_aligner.corpus.classes.File` + File to be added """ - if not self.files: - raise CorpusError( - "There were no wav files found for transcribing this corpus. Please validate the corpus. " - "This error can also be caused if you're trying to find non-wav files without sox available " - "on the system path." 
+ self.files[file.name] = file + for speaker in file.speaker_ordering: + if speaker.name not in self.speakers: + self.speakers[speaker.name] = speaker + else: + self.speakers[speaker.name].merge(speaker) + for u in file.utterances.values(): + self.utterances[u.name] = u + if u.text: + self.word_counts.update(u.text.split()) + + @property + def data_source_identifier(self) -> str: + """Corpus name""" + return os.path.basename(self.corpus_directory) + + def create_subset(self, subset: int) -> None: + """ + Create a subset of utterances to use for training + + Parameters + ---------- + subset: int + Number of utterances to include in subset + """ + subset_directory = os.path.join(self.corpus_output_directory, f"subset_{subset}") + + larger_subset_num = subset * 10 + if larger_subset_num < self.num_utterances: + # Get all shorter utterances that are not one word long + utts = sorted( + (utt for utt in self.utterances.values() if " " in utt.text), + key=lambda x: x.duration, ) + larger_subset = utts[:larger_subset_num] + else: + larger_subset = sorted(self.utterances.values()) + random.seed(1234) # make it deterministic sampling + subset_utts = set(random.sample(larger_subset, subset)) + log_dir = os.path.join(subset_directory, "log") + os.makedirs(log_dir, exist_ok=True) - if dictionary is not None: - for speaker in self.speakers.values(): - speaker.set_dictionary(dictionary.get_dictionary(speaker.name)) - self.initialize_jobs() for j in self.jobs: - j.set_feature_config(feature_config) - self.feature_config = feature_config - self.write() - self.split() - if self.feature_config is not None: - try: - self.generate_features() - except Exception as e: - if isinstance(e, KaldiProcessingError): - log_kaldi_errors(e.error_logs, self.logger) - e.update_log_file(self.logger.handlers[0].baseFilename) - raise + j.set_subset(subset_utts) + j.output_to_directory(subset_directory) @property def num_utterances(self) -> int: """Number of utterances in the corpus""" return 
len(self.utterances) - @property - def features_directory(self) -> str: - """Feature directory of the corpus""" - return os.path.join(self.output_directory, "features") - - @property - def features_log_directory(self) -> str: - """Feature log directory""" - return os.path.join(self.split_directory, "log") - - def speaker_utterance_info(self) -> str: + def subset_directory(self, subset: Optional[int]) -> str: """ - Construct message for analyzing high level detail about speakers and their utterances + Construct a subset directory for the corpus + + Parameters + ---------- + subset: int, optional + Number of utterances to include, if larger than the total number of utterance or not specified, the + split_directory is returned Returns ------- str - Analysis string + Path to subset directory """ + if subset is None or subset > self.num_utterances or subset <= 0: + for j in self.jobs: + j.set_subset(None) + return self.split_directory + directory = os.path.join(self.corpus_output_directory, f"subset_{subset}") + self.create_subset(subset) + return directory + + def _load_corpus(self) -> None: + """ + Load the corpus + """ + self.log_info("Setting up corpus information...") + loaded = self._load_corpus_from_temp() + if not loaded: + if self.use_mp: + self.log_debug("Loading from source with multiprocessing") + self._load_corpus_from_source_mp() + else: + self.log_debug("Loading from source without multiprocessing") + self._load_corpus_from_source() + else: + self.log_debug("Successfully loaded from temporary files") + if not self.files: + raise CorpusError( + "There were no files found for this corpus. Please validate the corpus." + ) num_speakers = len(self.speakers) if not num_speakers: raise CorpusError( @@ -763,180 +354,106 @@ def speaker_utterance_info(self) -> str: "and/or run the validation utility (mfa validate)." 
) average_utterances = sum(len(x.utterances) for x in self.speakers.values()) / num_speakers - msg = ( + self.log_info( f"Number of speakers in corpus: {num_speakers}, " f"average number of utterances per speaker: {average_utterances}" ) - return msg - @property - def split_directory(self) -> str: - """Directory used to store information split by job""" - directory = os.path.join(self.output_directory, f"split{self.num_jobs}") - return directory - - def generate_features(self, overwrite: bool = False, compute_cmvn: bool = True) -> None: + def _load_corpus_from_temp(self) -> bool: """ - Generate features for the corpus - - Parameters - ---------- - overwrite: bool - Flag for whether to ignore existing files, defaults to False - compute_cmvn: bool - Flag for whether to compute CMVN, defaults to True - """ - if not overwrite and os.path.exists(os.path.join(self.output_directory, "feats.scp")): - return - self.logger.info(f"Generating base features ({self.feature_config.type})...") - if self.feature_config.type == "mfcc": - mfcc(self) - self.combine_feats() - if compute_cmvn: - self.logger.info("Calculating CMVN...") - calc_cmvn(self) - self.write() - self.split() - - def compute_vad(self) -> None: - """ - Compute Voice Activity Dectection features over the corpus - """ - if os.path.exists(os.path.join(self.split_directory, "vad.0.scp")): - self.logger.info("VAD already computed, skipping!") - return - self.logger.info("Computing VAD...") - compute_vad(self) - - def combine_feats(self) -> None: - """ - Combine feature generation results and store relevant information - """ - split_directory = self.split_directory - ignore_check = [] - for job in self.jobs: - feats_paths = job.construct_path_dictionary(split_directory, "feats", "scp") - lengths_paths = job.construct_path_dictionary( - split_directory, "utterance_lengths", "scp" - ) - for dict_name in job.current_dictionary_names: - path = feats_paths[dict_name] - lengths_path = lengths_paths[dict_name] - if 
os.path.exists(lengths_path): - with open(lengths_path, "r") as inf: - for line in inf: - line = line.strip() - utt, length = line.split() - length = int(length) - if length < 13: # Minimum length to align one phone plus silence - self.utterances[utt].ignored = True - ignore_check.append(utt) - self.utterances[utt].feature_length = length - with open(path, "r") as inf: - for line in inf: - line = line.strip() - if line == "": - continue - f = line.split(maxsplit=1) - if self.utterances[f[0]].ignored: - continue - self.utterances[f[0]].features = f[1] - for u, utterance in self.utterances.items(): - if utterance.features is None: - utterance.ignored = True - ignore_check.append(u) - if ignore_check: - self.logger.warning( - "There were some utterances ignored due to short duration, see the log file for full " - "details or run `mfa validate` on the corpus." - ) - self.logger.debug( - f"The following utterances were too short to run alignment: " - f"{' ,'.join(ignore_check)}" - ) - self.write() - - def get_feat_dim(self) -> int: - """ - Calculate the feature dimension for the corpus + Load a corpus from saved data in the temporary directory Returns ------- - int - Dimension of feature vectors + bool + Whether loading from temporary files was successful """ - feature_string = self.jobs[0].construct_base_feature_string(self) - with open(os.path.join(self.features_log_directory, "feat-to-dim.log"), "w") as log_file: - dim_proc = subprocess.Popen( - [thirdparty_binary("feat-to-dim"), feature_string, "-"], - stdout=subprocess.PIPE, - stderr=log_file, - ) - stdout, stderr = dim_proc.communicate() - feats = stdout.decode("utf8").strip() - return int(feats) + begin_time = time.time() + if not os.path.exists(self.corpus_output_directory): + return False + for f in os.listdir(self.corpus_output_directory): + if f.startswith("split"): + old_num_jobs = int(f.replace("split", "")) + if old_num_jobs != self.num_jobs: + self.log_info( + f"Found old run with {old_num_jobs} rather 
than the current {self.num_jobs}, " + f"setting to {old_num_jobs}. If you would like to use {self.num_jobs}, re-run the command " + f"with --clean." + ) + self.num_jobs = old_num_jobs + speakers_path = os.path.join(self.corpus_output_directory, "speakers.yaml") + files_path = os.path.join(self.corpus_output_directory, "files.yaml") + utterances_path = os.path.join(self.corpus_output_directory, "utterances.yaml") - def write(self) -> None: - """ - Output information to the temporary directory for later loading - """ - self._write_speakers() - self._write_files() - self._write_utterances() - self._write_spk2utt() - self._write_feats() + if not os.path.exists(speakers_path): + self.log_debug(f"Could not find {speakers_path}, cannot load from temp") + return False + if not os.path.exists(files_path): + self.log_debug(f"Could not find {files_path}, cannot load from temp") + return False + if not os.path.exists(utterances_path): + self.log_debug(f"Could not find {utterances_path}, cannot load from temp") + return False + self.log_debug("Loading from temporary files...") - def _write_spk2utt(self): - """Write spk2utt scp file for Kaldi""" - data = { - speaker.name: sorted(speaker.utterances.keys()) for speaker in self.speakers.values() - } - output_mapping(data, os.path.join(self.output_directory, "spk2utt.scp")) + with open(speakers_path, "r", encoding="utf8") as f: + speaker_data = yaml.safe_load(f) - def write_utt2spk(self): - """Write utt2spk scp file for Kaldi""" - data = {u.name: u.speaker.name for u in self.utterances.values()} - output_mapping(data, os.path.join(self.output_directory, "utt2spk.scp")) + for entry in speaker_data: + self.speakers[entry["name"]] = Speaker(entry["name"]) + self.speakers[entry["name"]].cmvn = entry["cmvn"] - def _write_feats(self): - """Write feats scp file for Kaldi""" - if any(x.features is not None for x in self.utterances.values()): - with open(os.path.join(self.output_directory, "feats.scp"), "w", encoding="utf8") as f: - for 
utterance in self.utterances.values(): - if not utterance.features: - continue - f.write(f"{utterance.name} {utterance.features}\n") + with open(files_path, "r", encoding="utf8") as f: + files_data = yaml.safe_load(f) + for entry in files_data: + self.files[entry["name"]] = File( + entry["wav_path"], entry["text_path"], entry["relative_path"] + ) + self.files[entry["name"]].speaker_ordering = [ + self.speakers[x] for x in entry["speaker_ordering"] + ] + self.files[entry["name"]].wav_info = entry["wav_info"] - def _write_speakers(self): - """Write speaker information for speeding up future runs""" - to_save = [] - for speaker in self.speakers.values(): - to_save.append(speaker.meta) - with open(os.path.join(self.output_directory, "speakers.yaml"), "w", encoding="utf8") as f: - yaml.safe_dump(to_save, f) + with open(utterances_path, "r", encoding="utf8") as f: + utterances_data = yaml.safe_load(f) + for entry in utterances_data: + s = self.speakers[entry["speaker"]] + f = self.files[entry["file"]] + u = Utterance( + s, + f, + begin=entry["begin"], + end=entry["end"], + channel=entry["channel"], + text=entry["text"], + ) + self.utterances[u.name] = u + if u.text: + self.word_counts.update(u.text.split()) + self.utterances[u.name].features = entry["features"] + self.utterances[u.name].ignored = entry["ignored"] - def _write_files(self): - """Write file information for speeding up future runs""" - to_save = [] - for file in self.files.values(): - to_save.append(file.meta) - with open(os.path.join(self.output_directory, "files.yaml"), "w", encoding="utf8") as f: - yaml.safe_dump(to_save, f) + self.log_debug( + f"Loaded from corpus_data temp directory in {time.time() - begin_time} seconds" + ) + return True - def _write_utterances(self): - """Write utterance information for speeding up future runs""" - to_save = [] - for utterance in self.utterances.values(): - to_save.append(utterance.meta) - with open( - os.path.join(self.output_directory, "utterances.yaml"), "w", 
encoding="utf8" - ) as f: - yaml.safe_dump(to_save, f) + @property + def base_data_directory(self) -> str: + """Corpus data directory""" + return self.corpus_output_directory - def split(self) -> None: - """Create split directory and output information from Jobs""" - split_dir = self.split_directory - os.makedirs(os.path.join(split_dir, "log"), exist_ok=True) - self.logger.info("Setting up training data...") - for job in self.jobs: - job.output_to_directory(split_dir) + @property + def data_directory(self) -> str: + """Corpus data directory""" + return self.split_directory + + @abstractmethod + def _load_corpus_from_source_mp(self) -> None: + """Abstract method for loading a corpus with multiprocessing""" + ... + + @abstractmethod + def _load_corpus_from_source(self) -> None: + """Abstract method for loading a corpus without multiprocessing""" + ... diff --git a/montreal_forced_aligner/corpus/classes.py b/montreal_forced_aligner/corpus/classes.py index f82fbea9..5dc06f38 100644 --- a/montreal_forced_aligner/corpus/classes.py +++ b/montreal_forced_aligner/corpus/classes.py @@ -4,24 +4,21 @@ import os import sys import traceback -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Union +from collections import Counter +from typing import TYPE_CHECKING, Any, Callable, Optional, Union from praatio import textgrid from praatio.utilities.constants import Interval -from ..exceptions import CorpusError, TextGridParseError, TextParseError -from .helper import get_wav_info, load_text, parse_transcription +from montreal_forced_aligner.corpus.helper import get_wav_info, load_text, parse_transcription +from montreal_forced_aligner.exceptions import CorpusError, TextGridParseError, TextParseError if TYPE_CHECKING: - from ..config.align_config import AlignConfig - from ..config.dictionary_config import DictionaryConfig - from ..dictionary import Dictionary, DictionaryData - from ..textgrid import CtmType - from ..trainers import BaseTrainer, LdaTrainer, 
SatTrainer - - ConfigType = Union[BaseTrainer, AlignConfig] - FmllrConfigType = Union[SatTrainer, AlignConfig] - LdaConfigType = Union[LdaTrainer, AlignConfig] + from montreal_forced_aligner.abc import MetaDict + from montreal_forced_aligner.dictionary import DictionaryData + from montreal_forced_aligner.dictionary.mixins import SanitizeFunction + from montreal_forced_aligner.dictionary.pronunciation import PronunciationDictionaryMixin + from montreal_forced_aligner.textgrid import CtmInterval __all__ = ["parse_file", "File", "Speaker", "Utterance"] @@ -29,12 +26,11 @@ def parse_file( file_name: str, - wav_path: str, - text_path: str, + wav_path: Optional[str], + text_path: Optional[str], relative_path: str, speaker_characters: Union[int, str], - sample_rate: int = 16000, - dictionary_config: Optional[DictionaryConfig] = None, + sanitize_function: Optional[Callable] = None, stop_check: Optional[Callable] = None, ) -> File: """ @@ -52,24 +48,20 @@ def parse_file( Relative path from the corpus directory root speaker_characters: int, optional Number of characters in the file name to specify the speaker - sample_rate: int - Default sample rate for the corpus - punctuation: str - Orthographic characters to be treated as punctuation - clitic_markers: str - Orthographic characters to be treated as clitic markers + sanitize_function: Callable, optional + Function to sanitize words and strip punctuation stop_check: Callable Check whether to stop parsing early Returns ------- - :class:`~montreal_forced_aligner.corpus.File` + :class:`~montreal_forced_aligner.corpus.classes.File` Parsed file """ file = File(wav_path, text_path, relative_path=relative_path) if file.has_sound_file: root = os.path.dirname(wav_path) - file.wav_info = get_wav_info(wav_path, sample_rate=sample_rate) + file.wav_info = get_wav_info(wav_path) else: root = os.path.dirname(text_path) if not speaker_characters: @@ -85,7 +77,7 @@ def parse_file( root_speaker = Speaker(speaker_name) file.load_text( 
root_speaker=root_speaker, - dictionary_config=dictionary_config, + sanitize_function=sanitize_function, stop_check=stop_check, ) return file @@ -102,13 +94,13 @@ class Speaker: Attributes ---------- - utterances: Dict[str, :class:`~montreal_forced_aligner.corpus.Utterance`] + utterances: dict[str, :class:`~montreal_forced_aligner.corpus.classes.Utterance`] Utterances that the speaker is associated with cmvn: str, optional String pointing to any CMVN that has been calculated for this speaker dictionary: :class:`~montreal_forced_aligner.dictionary.PronunciationDictionary`, optional Pronunciation dictionary that the speaker is associated with - dictionary_data: DictionaryData, optional + dictionary_data: :class:`~montreal_forced_aligner.dictionary.DictionaryData`, optional Dictionary data from the speaker's dictionary """ @@ -116,22 +108,21 @@ def __init__(self, name): self.name = name self.utterances = {} self.cmvn = None - self.dictionary: Optional[Dictionary] = None + self.dictionary: Optional[PronunciationDictionaryMixin] = None self.dictionary_data: Optional[DictionaryData] = None + self.dictionary_name: Optional[str] = None + self.word_counts = Counter() def __getstate__(self): """Get dictionary for pickling""" - data = {"name": self.name, "cmvn": self.cmvn} - if self.dictionary_data is not None: - data["dictionary_data"] = self.dictionary_data + data = {"name": self.name, "cmvn": self.cmvn, "dictionary_name": self.dictionary_name} return data def __setstate__(self, state): """Recreate object following pickling""" self.name = state["name"] self.cmvn = state["cmvn"] - if "dictionary_data" in state: - self.dictionary_data = state["dictionary_data"] + self.dictionary_name = state["dictionary_name"] def __str__(self): """Return Speaker's name""" @@ -187,11 +178,13 @@ def add_utterance(self, utterance: Utterance): Parameters ---------- - utterance: :class:`~montreal_forced_aligner.corpus.Utterance` - Utterance + utterance: 
:class:`~montreal_forced_aligner.corpus.classes.Utterance` + Utterance to be added """ utterance.speaker = self self.utterances[utterance.name] = utterance + if utterance.text: + self.word_counts.update(utterance.text.split()) def delete_utterance(self, utterance: Utterance): """ @@ -199,7 +192,7 @@ def delete_utterance(self, utterance: Utterance): Parameters ---------- - utterance: :class:`~montreal_forced_aligner.corpus.Utterance` + utterance: :class:`~montreal_forced_aligner.corpus.classes.Utterance` Utterance to be deleted """ identifier = utterance.name @@ -212,29 +205,32 @@ def merge(self, speaker: Speaker): Parameters ---------- - speaker: :class:`~montreal_forced_aligner.corpus.Speaker` + speaker: :class:`~montreal_forced_aligner.corpus.classes.Speaker` Other speaker to take utterances from """ for u in speaker.utterances.values(): self.add_utterance(u) speaker.utterances = [] - def word_set(self) -> Set[str]: + def word_set(self) -> set[str]: """ Generate the word set of all the words in a speaker's utterances Returns ------- - Set[str] + set[str] Speaker's word set """ words = set() - for u in self.utterances.values(): - if u.text: - words.update(u.text.split()) + for word in self.word_counts: + if self.dictionary is not None: + word = self.dictionary._lookup(word) + words.update(word) + else: + words.add(word) return words - def set_dictionary(self, dictionary: Dictionary) -> None: + def set_dictionary(self, dictionary: PronunciationDictionaryMixin) -> None: """ Set the dictionary for the speaker @@ -244,10 +240,11 @@ def set_dictionary(self, dictionary: Dictionary) -> None: Pronunciation dictionary to associate with the speaker """ self.dictionary = dictionary + self.dictionary_name = dictionary.name self.dictionary_data = dictionary.data(self.word_set()) @property - def files(self) -> Set["File"]: + def files(self) -> set["File"]: """Files that the speaker is associated with""" files = set() for u in self.utterances.values(): @@ -301,10 +298,20 @@ def 
__init__( raise CorpusError("File objects must have either a wav_path or text_path") self.relative_path = relative_path self.wav_info = None - self.speaker_ordering: List[Speaker] = [] - self.utterances: Dict[str, Utterance] = {} + self.speaker_ordering: list[Speaker] = [] + self.utterances: dict[str, Utterance] = {} self.aligned = False + def has_fully_aligned_speaker(self, speaker: Speaker) -> bool: + for u in self.utterances.values(): + if u.speaker != speaker: + continue + if u.word_labels is None: + return False + if u.phone_labels is None: + return False + return True + def __repr__(self): """Representation of File objects""" return f'' @@ -484,7 +491,7 @@ def construct_output_path( def load_text( self, root_speaker: Optional[Speaker] = None, - dictionary_config: Optional[DictionaryConfig] = None, + sanitize_function: Optional[SanitizeFunction] = None, stop_check: Optional[Callable] = None, ) -> None: """ @@ -492,12 +499,10 @@ def load_text( Parameters ---------- - root_speaker: :class:`~montreal_forced_aligner.corpus.Speaker`, optional + root_speaker: :class:`~montreal_forced_aligner.corpus.classes.Speaker`, optional Speaker derived from the root directory, ignored for TextGrids - punctuation: str - Orthographic characters to treat as punctuation - clitic_markers: str - Orthographic characters to treat as clitic markers + sanitize_function: :class:`~montreal_forced_aligner.dictionary.mixins.SanitizeFunction`, optional + Function to sanitize words and strip punctuation stop_check: Callable Function to check whether this should break early """ @@ -506,7 +511,7 @@ def load_text( text = load_text(self.text_path) except UnicodeDecodeError: raise TextParseError(self.text_path) - words = parse_transcription(text, dictionary_config) + words = parse_transcription(text, sanitize_function) utterance = Utterance(speaker=root_speaker, file=self, text=" ".join(words)) self.add_utterance(utterance) elif self.text_type == "textgrid": @@ -540,7 +545,7 @@ def load_text( if 
stop_check is not None and stop_check(): return text = text.lower().strip() - words = parse_transcription(text, dictionary_config) + words = parse_transcription(text, sanitize_function) if not words: continue begin, end = round(begin, 4), round(end, 4) @@ -559,7 +564,7 @@ def add_speaker(self, speaker: Speaker) -> None: Parameters ---------- - speaker: :class:`~montreal_forced_aligner.corpus.Speaker` + speaker: :class:`~montreal_forced_aligner.corpus.classes.Speaker` Speaker to add """ if speaker not in self.speaker_ordering: @@ -571,7 +576,7 @@ def add_utterance(self, utterance: Utterance) -> None: Parameters ---------- - utterance: :class:`~montreal_forced_aligner.corpus.Utterance` + utterance: :class:`~montreal_forced_aligner.corpus.classes.Utterance` Utterance to add """ utterance.file = self @@ -584,7 +589,7 @@ def delete_utterance(self, utterance: Utterance) -> None: Parameters ---------- - utterance: :class:`~montreal_forced_aligner.corpus.Utterance` + utterance: :class:`~montreal_forced_aligner.corpus.classes.Utterance` Utterance to remove """ identifier = utterance.name @@ -650,9 +655,9 @@ class Utterance: Parameters ---------- - speaker: :class:`~montreal_forced_aligner.corpus.Speaker` + speaker: :class:`~montreal_forced_aligner.corpus.classes.Speaker` Speaker of the utterance - file: File + file: :class:`~montreal_forced_aligner.corpus.classes.File` File that the utterance belongs to begin: float, optional Start time of the utterance, @@ -680,11 +685,11 @@ class Utterance: Feature string reference to the computed features archive feature_length: int, optional Number of feature frames - phone_labels: CtmType, optional + phone_labels: list[:class:`~montreal_forced_aligner.data.CtmInterval`], optional Saved aligned phone labels - word_labels: CtmType, optional + word_labels: list[:class:`~montreal_forced_aligner.data.CtmInterval`], optional Saved aligned word labels - oovs: List[str] + oovs: list[str] Words not found in the dictionary for this utterance """ 
@@ -709,9 +714,9 @@ def __init__( self.ignored = False self.features = None self.feature_length = None - self.phone_labels: Optional[CtmType] = None - self.word_labels: Optional[CtmType] = None - self.oovs = [] + self.phone_labels: Optional[list[CtmInterval]] = None + self.word_labels: Optional[list[CtmInterval]] = None + self.oovs = set() self.speaker.add_utterance(self) self.file.add_utterance(self) @@ -767,7 +772,7 @@ def __eq__(self, other) -> bool: Parameters ---------- - other: :class:`~montreal_forced_aligner.corpus.Utterance` or str + other: :class:`~montreal_forced_aligner.corpus.classes.Utterance` or str Utterance to compare against Returns @@ -792,7 +797,7 @@ def __lt__(self, other) -> bool: Parameters ---------- - other: :class:`~montreal_forced_aligner.corpus.Utterance` or str + other: :class:`~montreal_forced_aligner.corpus.classes.Utterance` or str Utterance to compare against Returns @@ -816,7 +821,7 @@ def __lte__(self, other) -> bool: Parameters ---------- - other: :class:`~montreal_forced_aligner.corpus.Utterance` or str + other: :class:`~montreal_forced_aligner.corpus.classes.Utterance` or str Utterance to compare against Returns @@ -840,7 +845,7 @@ def __gt__(self, other) -> bool: Parameters ---------- - other: :class:`~montreal_forced_aligner.corpus.Utterance` or str + other: :class:`~montreal_forced_aligner.corpus.classes.Utterance` or str Utterance to compare against Returns @@ -865,7 +870,7 @@ def __gte__(self, other) -> bool: Parameters ---------- - other: :class:`~montreal_forced_aligner.corpus.Utterance` or str + other: :class:`~montreal_forced_aligner.corpus.classes.Utterance` or str Utterance to compare against Returns @@ -895,7 +900,7 @@ def duration(self) -> float: return self.file.duration @property - def meta(self) -> Dict[str, Any]: + def meta(self) -> MetaDict: """Metadata dictionary for the utterance""" return { "speaker": self.speaker.name, @@ -915,7 +920,7 @@ def set_speaker(self, speaker: Speaker): Parameters ---------- - 
speaker: :class:`~montreal_forced_aligner.corpus.Speaker` + speaker: :class:`~montreal_forced_aligner.corpus.classes.Speaker` New speaker """ self.speaker = speaker @@ -927,45 +932,45 @@ def is_segment(self): """Check if this utterance is a segment of a longer file""" return self.begin is not None and self.end is not None - def text_for_scp(self) -> List[str]: + def text_for_scp(self) -> list[str]: """ Generate the text for exporting to Kaldi's text scp Returns ------- - List[str] + list[str] List of words """ return self.text.split() - def text_int_for_scp(self) -> Optional[List[int]]: + def text_int_for_scp(self) -> Optional[list[int]]: """ Generate the text for exporting to Kaldi's text int scp Returns ------- - List[int] + list[int] List of word IDs, or None if the utterance's speaker doesn't have an associated dictionary """ - if self.speaker.dictionary is None: + if self.speaker.dictionary_data is None: return text = self.text_for_scp() new_text = [] for i, t in enumerate(text): - lookup = self.speaker.dictionary.to_int(t) + lookup = self.speaker.dictionary_data.to_int(t) for w in lookup: - if w == self.speaker.dictionary.oov_int: - self.oovs.append(text[i]) + if w == self.speaker.dictionary_data.oov_int: + self.oovs.add(text[i]) new_text.append(w) return new_text - def segment_for_scp(self) -> List[Any]: + def segment_for_scp(self) -> list[Any]: """ Generate data for Kaldi's segments scp file Returns ------- - List[Any] + list[Any] Segment data """ return [self.file.name, self.begin, self.end, self.channel] @@ -973,7 +978,9 @@ def segment_for_scp(self) -> List[Any]: @property def name(self): """The name of the utterance""" - base = f"{self.file_name}-{self.speaker_name}" + base = f"{self.file_name}" + if not base.startswith(f"{self.speaker_name}-"): + base = f"{self.speaker_name}-" + base if self.is_segment: - base = f"{self.file_name}-{self.speaker_name}-{self.begin}-{self.end}" + base = f"{self.file_name}-{self.begin}-{self.end}" return base.replace(" ", 
"-space-").replace(".", "-").replace("_", "-") diff --git a/montreal_forced_aligner/corpus/features.py b/montreal_forced_aligner/corpus/features.py new file mode 100644 index 00000000..a45b520d --- /dev/null +++ b/montreal_forced_aligner/corpus/features.py @@ -0,0 +1,836 @@ +"""Classes for configuring feature generation""" +from __future__ import annotations + +import os +import subprocess +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, NamedTuple, Union + +from montreal_forced_aligner.utils import thirdparty_binary + +if TYPE_CHECKING: + SpeakerCharacterType = Union[str, int] + from montreal_forced_aligner.abc import MetaDict + +__all__ = [ + "FeatureConfigMixin", + "mfcc_func", + "calc_fmllr_func", + "compute_vad_func", + "VadArguments", + "MfccArguments", + "CalcFmllrArguments", +] + + +class VadArguments(NamedTuple): + """Arguments for :func:`~montreal_forced_aligner.corpus.features.compute_vad_func`""" + + log_path: str + dictionaries: list[str] + feats_scp_paths: dict[str, str] + vad_scp_paths: dict[str, str] + vad_options: MetaDict + + +class MfccArguments(NamedTuple): + """ + Arguments for :func:`~montreal_forced_aligner.corpus.features.mfcc_func` + """ + + log_path: str + dictionaries: list[str] + feats_scp_paths: dict[str, str] + lengths_paths: dict[str, str] + segment_paths: dict[str, str] + wav_paths: dict[str, str] + mfcc_options: MetaDict + + +class CalcFmllrArguments(NamedTuple): + """Arguments for :func:`~montreal_forced_aligner.corpus.features.calc_fmllr_func`""" + + log_path: str + dictionaries: list[str] + feature_strings: dict[str, str] + ali_paths: dict[str, str] + ali_model_path: str + model_path: str + spk2utt_paths: dict[str, str] + trans_paths: dict[str, str] + fmllr_options: MetaDict + + +def make_safe(value: Any) -> str: + """ + Transform an arbitrary value into a string + + Parameters + ---------- + value: Any + Value to make safe + + Returns + ------- + str + Safe value + """ + if isinstance(value, bool): + 
return str(value).lower() + return str(value) + + +def mfcc_func( + log_path: str, + dictionaries: list[str], + feats_scp_paths: dict[str, str], + lengths_paths: dict[str, str], + segment_paths: dict[str, str], + wav_paths: dict[str, str], + mfcc_options: MetaDict, +) -> None: + """ + Multiprocessing function for generating MFCC features + + See Also + -------- + :meth:`.AcousticCorpusMixin.mfcc` + Main function that calls this function in parallel + :meth:`.AcousticCorpusMixin.mfcc_arguments` + Job method for generating arguments for this function + :kaldi_src:`compute-mfcc-feats` + Relevant Kaldi binary + :kaldi_src:`extract-segments` + Relevant Kaldi binary + :kaldi_src:`copy-feats` + Relevant Kaldi binary + :kaldi_src:`feat-to-len` + Relevant Kaldi binary + + Parameters + ---------- + log_path: str + Path to save log output + dictionaries: list[str] + List of dictionary names + feats_scp_paths: dict[str, str] + Dictionary of feature scp files per dictionary name + lengths_paths: dict[str, str] + Dictionary of feature lengths files per dictionary name + segment_paths: dict[str, str] + Dictionary of segment scp files per dictionary name + wav_paths: dict[str, str] + Dictionary of sound file scp files per dictionary name + mfcc_options: dict[str, Any] + Options for MFCC generation + """ + with open(log_path, "w") as log_file: + for dict_name in dictionaries: + mfcc_base_command = [thirdparty_binary("compute-mfcc-feats"), "--verbose=2"] + raw_ark_path = feats_scp_paths[dict_name].replace(".scp", ".ark") + for k, v in mfcc_options.items(): + mfcc_base_command.append(f"--{k.replace('_', '-')}={make_safe(v)}") + if os.path.exists(segment_paths[dict_name]): + mfcc_base_command += ["ark:-", "ark:-"] + seg_proc = subprocess.Popen( + [ + thirdparty_binary("extract-segments"), + f"scp,p:{wav_paths[dict_name]}", + segment_paths[dict_name], + "ark:-", + ], + stdout=subprocess.PIPE, + stderr=log_file, + env=os.environ, + ) + comp_proc = subprocess.Popen( + mfcc_base_command, 
+ stdout=subprocess.PIPE, + stderr=log_file, + stdin=seg_proc.stdout, + env=os.environ, + ) + else: + mfcc_base_command += [f"scp,p:{wav_paths[dict_name]}", "ark:-"] + comp_proc = subprocess.Popen( + mfcc_base_command, stdout=subprocess.PIPE, stderr=log_file, env=os.environ + ) + copy_proc = subprocess.Popen( + [ + thirdparty_binary("copy-feats"), + "--compress=true", + "ark:-", + f"ark,scp:{raw_ark_path},{feats_scp_paths[dict_name]}", + ], + stdin=comp_proc.stdout, + stderr=log_file, + env=os.environ, + ) + copy_proc.communicate() + + utt_lengths_proc = subprocess.Popen( + [ + thirdparty_binary("feat-to-len"), + f"scp:{feats_scp_paths[dict_name]}", + f"ark,t:{lengths_paths[dict_name]}", + ], + stderr=log_file, + env=os.environ, + ) + utt_lengths_proc.communicate() + + +def compute_vad_func( + log_path: str, + dictionaries: list[str], + feats_scp_paths: dict[str, str], + vad_scp_paths: dict[str, str], + vad_options: MetaDict, +) -> None: + """ + Multiprocessing function to compute voice activity detection + + See Also + -------- + :meth:`.AcousticCorpusMixin.compute_vad` + Main function that calls this function in parallel + :meth:`.AcousticCorpusMixin.compute_vad_arguments` + Job method for generating arguments for this function + :kaldi_src:`compute-vad` + Relevant Kaldi binary + + Parameters + ---------- + log_path: str + Path to save log output + dictionaries: list[str] + List of dictionary names + feats_scp_paths: dict[str, str] + Dictionary of feature scp files per dictionary name + vad_scp_paths: dict[str, str] + Dictionary of vad scp files per dictionary name + vad_options: dict[str, Any] + Options for VAD + """ + with open(log_path, "w") as log_file: + for dict_name in dictionaries: + feats_scp_path = feats_scp_paths[dict_name] + vad_scp_path = vad_scp_paths[dict_name] + vad_proc = subprocess.Popen( + [ + thirdparty_binary("compute-vad"), + f"--vad-energy-mean-scale={vad_options['energy_mean_scale']}", + 
f"--vad-energy-threshold={vad_options['energy_threshold']}", + f"scp:{feats_scp_path}", + f"ark,t:{vad_scp_path}", + ], + stderr=log_file, + env=os.environ, + ) + vad_proc.communicate() + + +def calc_fmllr_func( + log_path: str, + dictionaries: list[str], + feature_strings: dict[str, str], + ali_paths: dict[str, str], + ali_model_path: str, + model_path: str, + spk2utt_paths: dict[str, str], + trans_paths: dict[str, str], + fmllr_options: MetaDict, +) -> None: + """ + Multiprocessing function for calculating fMLLR transforms + + See Also + -------- + :meth:`.AcousticCorpusMixin.calc_fmllr` + Main function that calls this function in parallel + :meth:`.AcousticCorpusMixin.calc_fmllr_arguments` + Job method for generating arguments for this function + :kaldi_src:`gmm-est-fmllr` + Relevant Kaldi binary + :kaldi_src:`gmm-est-fmllr-gpost` + Relevant Kaldi binary + :kaldi_src:`gmm-post-to-gpost` + Relevant Kaldi binary + :kaldi_src:`ali-to-post` + Relevant Kaldi binary + :kaldi_src:`weight-silence-post` + Relevant Kaldi binary + :kaldi_src:`compose-transforms` + Relevant Kaldi binary + :kaldi_src:`transform-feats` + Relevant Kaldi binary + + Parameters + ---------- + log_path: str + Path to save log output + dictionaries: list[str] + List of dictionary names + feature_strings: dict[str, str] + Dictionary of feature strings per dictionary name + ali_paths: dict[str, str] + Dictionary of alignment archives per dictionary name + ali_model_path: str + Path to the alignment acoustic model file + model_path: str + Path to the acoustic model file + spk2utt_paths: dict[str, str] + Dictionary of spk2utt scps per dictionary name + trans_paths: dict[str, str] + Dictionary of fMLLR transform archives per dictionary name + fmllr_options: dict[str, Any] + Options for fMLLR estimation + """ + with open(log_path, "w", encoding="utf8") as log_file: + log_file.writelines(f"{k}: {v}\n" for k, v in os.environ.items()) + for dict_name in dictionaries: + feature_string = 
feature_strings[dict_name] + ali_path = ali_paths[dict_name] + spk2utt_path = spk2utt_paths[dict_name] + trans_path = trans_paths[dict_name] + post_proc = subprocess.Popen( + [thirdparty_binary("ali-to-post"), f"ark:{ali_path}", "ark:-"], + stderr=log_file, + stdout=subprocess.PIPE, + env=os.environ, + ) + + weight_proc = subprocess.Popen( + [ + thirdparty_binary("weight-silence-post"), + "0.0", + fmllr_options["silence_csl"], + ali_model_path, + "ark:-", + "ark:-", + ], + stderr=log_file, + stdin=post_proc.stdout, + stdout=subprocess.PIPE, + env=os.environ, + ) + + if ali_model_path != model_path: + post_gpost_proc = subprocess.Popen( + [ + thirdparty_binary("gmm-post-to-gpost"), + ali_model_path, + feature_string, + "ark:-", + "ark:-", + ], + stderr=log_file, + stdin=weight_proc.stdout, + stdout=subprocess.PIPE, + env=os.environ, + ) + est_proc = subprocess.Popen( + [ + thirdparty_binary("gmm-est-fmllr-gpost"), + "--verbose=4", + f"--fmllr-update-type={fmllr_options['fmllr_update_type']}", + f"--spk2utt=ark:{spk2utt_path}", + model_path, + feature_string, + "ark,s,cs:-", + f"ark:{trans_path}", + ], + stderr=log_file, + stdin=post_gpost_proc.stdout, + env=os.environ, + ) + est_proc.communicate() + + else: + + if os.path.exists(trans_path): + cmp_trans_path = trans_paths[dict_name] + ".tmp" + est_proc = subprocess.Popen( + [ + thirdparty_binary("gmm-est-fmllr"), + "--verbose=4", + f"--fmllr-update-type={fmllr_options['fmllr_update_type']}", + f"--spk2utt=ark:{spk2utt_path}", + model_path, + feature_string, + "ark:-", + "ark:-", + ], + stderr=log_file, + stdin=weight_proc.stdout, + stdout=subprocess.PIPE, + env=os.environ, + ) + comp_proc = subprocess.Popen( + [ + thirdparty_binary("compose-transforms"), + "--b-is-affine=true", + "ark:-", + f"ark:{trans_path}", + f"ark:{cmp_trans_path}", + ], + stderr=log_file, + stdin=est_proc.stdout, + env=os.environ, + ) + comp_proc.communicate() + + os.remove(trans_path) + os.rename(cmp_trans_path, trans_path) + else: + est_proc 
= subprocess.Popen( + [ + thirdparty_binary("gmm-est-fmllr"), + "--verbose=4", + f"--fmllr-update-type={fmllr_options['fmllr_update_type']}", + f"--spk2utt=ark:{spk2utt_path}", + model_path, + feature_string, + "ark,s,cs:-", + f"ark:{trans_path}", + ], + stderr=log_file, + stdin=weight_proc.stdout, + env=os.environ, + ) + est_proc.communicate() + + +class FeatureConfigMixin: + """ + Class to store configuration information about MFCC generation + + Attributes + ---------- + feature_type : str + Feature type, defaults to "mfcc" + use_energy : bool + Flag for whether first coefficient should be used, defaults to False + frame_shift : int + number of milliseconds between frames, defaults to 10 + snip_edges : bool + Flag for enabling Kaldi's snip edges, should be better time precision + pitch : bool + Flag for including pitch in features, currently nonfunctional, defaults to False + low_frequency : int + Frequency floor + high_frequency : int + Frequency ceiling + sample_frequency : int + Sampling frequency + allow_downsample : bool + Flag for whether to allow downsampling, default is True + allow_upsample : bool + Flag for whether to allow upsampling, default is True + speaker_independent : bool + Flag for whether features are speaker independent, default is True + uses_cmvn : bool + Flag for whether to use CMVN, default is True + uses_deltas : bool + Flag for whether to use delta features, default is True + uses_splices : bool + Flag for whether to use splices and LDA transformations, default is False + uses_speaker_adaptation : bool + Flag for whether to use speaker adaptation, default is False + fmllr_update_type : str + Type of fMLLR estimation, defaults to "full" + silence_weight : float + Weight of silence in calculating LDA or fMLLR + splice_left_context : int or None + Number of frames to splice on the left for calculating LDA + splice_right_context : int or None + Number of frames to splice on the right for calculating LDA + """ + + def __init__( + self, + 
feature_type: str = "mfcc", + use_energy: bool = False, + frame_shift: int = 10, + snip_edges: bool = True, + pitch: bool = False, + low_frequency: int = 20, + high_frequency: int = 7800, + sample_frequency: int = 16000, + allow_downsample: bool = True, + allow_upsample: bool = True, + speaker_independent: bool = True, + uses_cmvn: bool = True, + uses_deltas: bool = True, + uses_splices: bool = False, + uses_voiced: bool = False, + uses_speaker_adaptation: bool = False, + fmllr_update_type: str = "full", + silence_weight: float = 0.0, + splice_left_context: int = 3, + splice_right_context: int = 3, + **kwargs, + ): + super().__init__(**kwargs) + self.feature_type = feature_type + self.use_energy = use_energy + self.frame_shift = frame_shift + self.snip_edges = snip_edges + self.pitch = pitch + self.low_frequency = low_frequency + self.high_frequency = high_frequency + self.sample_frequency = sample_frequency + self.allow_downsample = allow_downsample + self.allow_upsample = allow_upsample + self.speaker_independent = speaker_independent + self.uses_cmvn = uses_cmvn + self.uses_deltas = uses_deltas + self.uses_splices = uses_splices + self.uses_voiced = uses_voiced + self.uses_speaker_adaptation = uses_speaker_adaptation + self.fmllr_update_type = fmllr_update_type + self.silence_weight = silence_weight + self.splice_left_context = splice_left_context + self.splice_right_context = splice_right_context + + @property + def vad_options(self) -> MetaDict: + """Abstract method for VAD options""" + raise NotImplementedError + + @property + def alignment_model_path(self) -> str: # needed for fmllr + """Abstract method for alignment model path""" + raise NotImplementedError + + @property + def model_path(self) -> str: # needed for fmllr + """Abstract method for model path""" + raise NotImplementedError + + @property + @abstractmethod + def working_directory(self) -> str: + """Abstract method for working directory""" + ... 
+ + @property + @abstractmethod + def corpus_output_directory(self) -> str: + """Abstract method for working directory of corpus""" + ... + + @property + @abstractmethod + def data_directory(self) -> str: + """Abstract method for corpus data directory""" + ... + + @property + def feature_options(self) -> MetaDict: + """Parameters for feature generation""" + options = { + "type": self.feature_type, + "use_energy": self.use_energy, + "frame_shift": self.frame_shift, + "snip_edges": self.snip_edges, + "low_frequency": self.low_frequency, + "high_frequency": self.high_frequency, + "sample_frequency": self.sample_frequency, + "allow_downsample": self.allow_downsample, + "allow_upsample": self.allow_upsample, + "pitch": self.pitch, + "uses_cmvn": self.uses_cmvn, + "uses_deltas": self.uses_deltas, + "uses_voiced": self.uses_voiced, + "uses_splices": self.uses_splices, + "uses_speaker_adaptation": self.uses_speaker_adaptation, + } + if self.uses_splices: + options.update( + { + "splice_left_context": self.splice_left_context, + "splice_right_context": self.splice_right_context, + } + ) + return options + + @abstractmethod + def calc_fmllr(self) -> None: + """Abstract method for calculating fMLLR transforms""" + ... 
+ + @property + def fmllr_options(self) -> MetaDict: + """Options for use in calculating fMLLR transforms""" + return { + "fmllr_update_type": self.fmllr_update_type, + "silence_weight": self.silence_weight, + "silence_csl": getattr( + self, "silence_csl", "" + ), # If we have silence phones from a dictionary, use them + } + + @property + def mfcc_options(self) -> MetaDict: + """Parameters to use in computing MFCC features.""" + return { + "use-energy": self.use_energy, + "frame-shift": self.frame_shift, + "low-freq": self.low_frequency, + "high-freq": self.high_frequency, + "sample-frequency": self.sample_frequency, + "allow-downsample": self.allow_downsample, + "allow-upsample": self.allow_upsample, + "snip-edges": self.snip_edges, + } + + +class IvectorConfigMixin(FeatureConfigMixin): + """ + Mixin class for ivector features + + Parameters + ---------- + ivector_dimension: int + Dimension of ivectors + num_gselect: int + Gaussian-selection using diagonal model: number of Gaussians to select + posterior_scale: float + Scale on the acoustic posteriors, intended to account for inter-frame correlations + min_post : float + Minimum posterior to use (posteriors below this are pruned out) + max_count: int + The use of this option (e.g. --max-count 100) can make iVectors more consistent for different lengths of + utterance, by scaling up the prior term when the data-count exceeds this value. The data-count is after + posterior-scaling, so assuming the posterior-scale is 0.1, --max-count 100 starts having effect after 1000 + frames, or 10 seconds of data. 
+ + See Also + -------- + :class:`~montreal_forced_aligner.corpus.features.FeatureConfigMixin` + For feature generation parameters + """ + + def __init__( + self, + ivector_dimension=128, + num_gselect=20, + posterior_scale=1.0, + min_post=0.025, + max_count=100, + **kwargs, + ): + super().__init__(**kwargs) + self.ivector_dimension = ivector_dimension + self.num_gselect = num_gselect + self.posterior_scale = posterior_scale + self.min_post = min_post + self.max_count = max_count + + @abstractmethod + def extract_ivectors(self): + """Abstract method for extracting ivectors""" + ... + + @property + def ivector_options(self) -> MetaDict: + """Options for ivector training and extracting""" + return { + "num_gselect": self.num_gselect, + "posterior_scale": self.posterior_scale, + "min_post": self.min_post, + "silence_weight": self.silence_weight, + "max_count": self.max_count, + "ivector_dimension": self.ivector_dimension, + "silence_csl": getattr( + self, "silence_csl", "" + ), # If we have silence phones from a dictionary, use them, + } + + +class VadConfigMixin(FeatureConfigMixin): + """ + Abstract mixin class for performing voice activity detection + + Parameters + ---------- + use_energy: bool + Flag for using the first coefficient of MFCCs + energy_threshold: float + Energy threshold above which a frame will be counted as voiced + energy_mean_scale: float + Proportion of the mean energy of the file that should be added to the energy_threshold + + See Also + -------- + :class:`~montreal_forced_aligner.corpus.features.FeatureConfigMixin` + For feature generation parameters + """ + + def __init__(self, use_energy=True, energy_threshold=5.5, energy_mean_scale=0.5, **kwargs): + super().__init__(**kwargs) + self.use_energy = use_energy + self.energy_threshold = energy_threshold + self.energy_mean_scale = energy_mean_scale + + @property + def vad_options(self) -> MetaDict: + """Options for performing VAD""" + return { + "energy_threshold": self.energy_threshold, + 
"energy_mean_scale": self.energy_mean_scale, + } + + +class ExtractIvectorsArguments(NamedTuple): + """Arguments for :func:`~montreal_forced_aligner.corpus.features.extract_ivectors_func`""" + + log_path: str + dictionaries: list[str] + feature_strings: dict[str, str] + ivector_options: MetaDict + ali_paths: dict[str, str] + ie_path: str + ivector_paths: dict[str, str] + weight_paths: dict[str, str] + model_path: str + dubm_path: str + + +def extract_ivectors_func( + log_path: str, + dictionaries: list[str], + feature_strings: dict[str, str], + ivector_options: MetaDict, + ali_paths: dict[str, str], + ie_path: str, + ivector_paths: dict[str, str], + weight_paths: dict[str, str], + model_path: str, + dubm_path: str, +) -> None: + """ + Multiprocessing function for extracting ivectors. + + See Also + -------- + :meth:`.IvectorCorpusMixin.extract_ivectors` + Main function that calls this function in parallel + :meth:`.IvectorCorpusMixin.extract_ivectors_arguments` + Job method for generating arguments for this function + :kaldi_src:`ivector-extract` + Relevant Kaldi binary + :kaldi_src:`gmm-global-get-post` + Relevant Kaldi binary + :kaldi_src:`weight-silence-post` + Relevant Kaldi binary + :kaldi_src:`weight-post` + Relevant Kaldi binary + :kaldi_src:`post-to-weights` + Relevant Kaldi binary + + Parameters + ---------- + log_path: str + Path to save log output + dictionaries: list[str] + List of dictionary names + feature_strings: dict[str, str] + Dictionary of feature strings per dictionary name + ivector_options: dict[str, Any] + Options for ivector extraction + ali_paths: dict[str, str] + Dictionary of alignment archives per dictionary name + ie_path: str + Path to the ivector extractor file + ivector_paths: dict[str, str] + Dictionary of ivector archives per dictionary name + weight_paths: dict[str, str] + Dictionary of weighted archives per dictionary name + model_path: str + Path to the acoustic model file + dubm_path: str + Path to the DUBM file + """ + with 
open(log_path, "w", encoding="utf8") as log_file: + for dict_name in dictionaries: + ali_path = ali_paths[dict_name] + weight_path = weight_paths[dict_name] + ivectors_path = ivector_paths[dict_name] + feature_string = feature_strings[dict_name] + use_align = os.path.exists(ali_path) + if use_align: + ali_to_post_proc = subprocess.Popen( + [thirdparty_binary("ali-to-post"), f"ark:{ali_path}", "ark:-"], + stderr=log_file, + stdout=subprocess.PIPE, + env=os.environ, + ) + weight_silence_proc = subprocess.Popen( + [ + thirdparty_binary("weight-silence-post"), + str(ivector_options["silence_weight"]), + ivector_options["sil_phones"], + model_path, + "ark:-", + "ark:-", + ], + stderr=log_file, + stdin=ali_to_post_proc.stdout, + stdout=subprocess.PIPE, + env=os.environ, + ) + post_to_weight_proc = subprocess.Popen( + [thirdparty_binary("post-to-weights"), "ark:-", f"ark:{weight_path}"], + stderr=log_file, + stdin=weight_silence_proc.stdout, + env=os.environ, + ) + post_to_weight_proc.communicate() + + gmm_global_get_post_proc = subprocess.Popen( + [ + thirdparty_binary("gmm-global-get-post"), + f"--n={ivector_options['num_gselect']}", + f"--min-post={ivector_options['min_post']}", + dubm_path, + feature_string, + "ark:-", + ], + stdout=subprocess.PIPE, + stderr=log_file, + env=os.environ, + ) + if use_align: + weight_proc = subprocess.Popen( + [ + thirdparty_binary("weight-post"), + "ark:-", + f"ark,s,cs:{weight_path}", + "ark:-", + ], + stdin=gmm_global_get_post_proc.stdout, + stdout=subprocess.PIPE, + stderr=log_file, + env=os.environ, + ) + extract_in = weight_proc.stdout + else: + extract_in = gmm_global_get_post_proc.stdout + extract_proc = subprocess.Popen( + [ + thirdparty_binary("ivector-extract"), + f"--acoustic-weight={ivector_options['posterior_scale']}", + "--compute-objf-change=true", + f"--max-count={ivector_options['max_count']}", + ie_path, + feature_string, + "ark,s,cs:-", + f"ark,t:{ivectors_path}", + ], + stderr=log_file, + stdin=extract_in, + 
def parse_transcription(text: str, sanitize_function: Optional["SanitizeFunction"] = None) -> list[str]:
    """
    Parse an orthographic transcription given punctuation and clitic markers

    Parameters
    ----------
    text: str
        Orthographic text to parse
    sanitize_function: :class:`~montreal_forced_aligner.dictionary.mixins.SanitizeFunction`, optional
        Function to sanitize words and strip punctuation

    Returns
    -------
    list[str]
        Parsed orthographic transcript
    """
    # Bug fix: the signature previously read ``sanitize_function=Optional[SanitizeFunction]``,
    # which binds the typing object itself as the default value.  The
    # ``is not None`` branch then ran (and crashed on attribute access) even
    # when no sanitize function was supplied.  The default is now ``None``.
    if sanitize_function is not None:
        words = []
        for raw_word in text.split():
            # Tokens that are themselves clitic/compound markers are dropped
            if raw_word in sanitize_function.clitic_markers + sanitize_function.compound_markers:
                continue
            sanitized = sanitize_function(raw_word)
            # Drop words that sanitize to the empty string (e.g. bare
            # punctuation), matching the behavior of the pre-refactor code,
            # which filtered falsy results
            if sanitized:
                words.append(sanitized)
    else:
        words = text.split()
    return words
class IvectorCorpusMixin(AcousticCorpusMixin, IvectorConfigMixin):
    """
    Abstract corpus mixin for corpora that extract ivectors

    See Also
    --------
    :class:`~montreal_forced_aligner.corpus.acoustic_corpus.AcousticCorpusMixin`
        For dictionary and corpus parsing parameters
    :class:`~montreal_forced_aligner.corpus.features.IvectorConfigMixin`
        For ivector extraction parameters
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @property
    def ie_path(self) -> str:
        """Ivector extractor ie path (must be provided by subclasses)"""
        raise NotImplementedError

    @property
    def dubm_path(self) -> str:
        """DUBM model path (must be provided by subclasses)"""
        # Bug fix: this was a bare ``raise``, which outside of an ``except``
        # block raises ``RuntimeError: No active exception to re-raise``.
        # Raise NotImplementedError to match ie_path's abstract contract.
        raise NotImplementedError

    def write_corpus_information(self) -> None:
        """
        Output information to the temporary directory for later loading
        """
        super().write_corpus_information()
        self._write_utt2spk()

    def _write_utt2spk(self) -> None:
        """Write the utt2spk scp file (utterance -> speaker mapping) for Kaldi"""
        # Doc fix: the previous docstring incorrectly said this wrote the
        # feats scp file.
        with open(
            os.path.join(self.corpus_output_directory, "utt2spk.scp"), "w", encoding="utf8"
        ) as f:
            for utterance in self.utterances.values():
                f.write(f"{utterance.name} {utterance.speaker.name}\n")

    def extract_ivectors_arguments(self) -> list[ExtractIvectorsArguments]:
        """
        Generate Job arguments for :func:`~montreal_forced_aligner.corpus.features.extract_ivectors_func`

        Returns
        -------
        list[ExtractIvectorsArguments]
            Arguments for processing
        """
        feat_strings = self.construct_feature_proc_strings()
        return [
            ExtractIvectorsArguments(
                os.path.join(self.working_log_directory, f"extract_ivectors.{j.name}.log"),
                j.current_dictionary_names,
                feat_strings[j.name],
                self.ivector_options,
                j.construct_path_dictionary(self.working_directory, "ali", "ark"),
                self.ie_path,
                j.construct_path_dictionary(self.working_directory, "ivectors", "scp"),
                j.construct_path_dictionary(self.working_directory, "weights", "ark"),
                self.model_path,
                self.dubm_path,
            )
            for j in self.jobs
        ]

    def extract_ivectors(self) -> None:
        """
        Multiprocessing function that extracts ivectors for all jobs.

        See Also
        --------
        :func:`~montreal_forced_aligner.corpus.features.extract_ivectors_func`
            Multiprocessing helper function for each job
        :meth:`.IvectorCorpusMixin.extract_ivectors_arguments`
            Job method for generating arguments for helper function
        :kaldi_steps_sid:`extract_ivectors`
            Reference Kaldi script
        """
        log_dir = self.working_log_directory
        os.makedirs(log_dir, exist_ok=True)

        jobs = self.extract_ivectors_arguments()
        if self.use_mp:
            run_mp(extract_ivectors_func, jobs, log_dir)
        else:
            run_non_mp(extract_ivectors_func, jobs, log_dir)
class CorpusProcessWorker(mp.Process):
    """
    Multiprocessing corpus loading worker

    Attributes
    ----------
    job_q: :class:`~multiprocessing.Queue`
        Job queue for files to process
    return_dict: dict
        Dictionary to catch errors
    return_q: :class:`~multiprocessing.Queue`
        Return queue for processed Files
    stopped: :func:`~montreal_forced_aligner.utils.Stopped`
        Stop check for whether corpus loading should exit
    finished_adding: :class:`~montreal_forced_aligner.utils.Stopped`
        Signal that the main thread has stopped adding new files to be processed
    """

    def __init__(
        self,
        job_q: mp.Queue,
        return_dict: dict,
        return_q: mp.Queue,
        stopped: Stopped,
        finished_adding: Stopped,
    ):
        super().__init__()
        self.job_q = job_q
        self.return_dict = return_dict
        self.return_q = return_q
        self.stopped = stopped
        self.finished_adding = finished_adding

    def run(self) -> None:
        """
        Run the corpus loading job
        """
        # Imported lazily inside the worker process rather than at module load
        from ..corpus.classes import parse_file

        while True:
            try:
                job_args = self.job_q.get(timeout=1)
            except Empty:
                # Queue is empty: exit only once the producer signals it is done
                if self.finished_adding.stop_check():
                    break
                continue
            self.job_q.task_done()
            if self.stopped.stop_check():
                # A stop was requested: keep draining the queue without parsing
                continue
            try:
                parsed = parse_file(*job_args, stop_check=self.stopped.stop_check)
                self.return_q.put(parsed)
            except TextParseError as e:
                self.return_dict["decode_error_files"].append(e)
            except TextGridParseError as e:
                self.return_dict["textgrid_read_errors"][e.file_name] = e
            except Exception:
                # Unexpected failure: record the traceback, signal all workers
                # to stop, and terminate this worker immediately
                self.stopped.stop()
                self.return_dict["error"] = job_args, Exception(
                    traceback.format_exception(*sys.exc_info())
                )
                return
    def add_speaker(self, speaker: Speaker) -> None:
        """
        Add a speaker to a job

        Parameters
        ----------
        speaker: :class:`~montreal_forced_aligner.corpus.classes.Speaker`
            Speaker to add
        """
        # Track the speaker's dictionary too: most Job outputs are organized
        # per pronunciation dictionary
        self.speakers.append(speaker)
        self.dictionaries.add(speaker.dictionary)
""" + Set the current subset for the trainer + + Parameters + ---------- + subset_utts: Collection[:class:`~montreal_forced_aligner.corpus.classes.Utterance`], optional + Subset of utterances for this job to use + """ + if subset_utts is None: + self.subset_utts = set() + self.subset_speakers = set() + self.subset_dictionaries = set() + else: + self.subset_utts = set(u for u in subset_utts if u.speaker in self.speakers) + self.subset_speakers = {u.speaker for u in subset_utts if u.speaker in self.speakers} + self.subset_dictionaries = {s.dictionary for s in self.subset_speakers} + + def text_scp_data(self) -> dict[str, dict[str, list[str]]]: + """ + Generate the job's data for Kaldi's text scp files + + Returns + ------- + dict[str, dict[str, list[str]]] + Text for each utterance, per dictionary name + """ + data = {} + utts = self.job_utts() + for dict_name, utt_data in utts.items(): + data[dict_name] = {} + for utt in utt_data.values(): + if not utt.text: + continue + data[dict_name][utt.name] = " ".join(map(str, utt.text_for_scp())) + return data + + def text_int_scp_data(self) -> dict[str, dict[str, str]]: + """ + Generate the job's data for Kaldi's text int scp files + + Returns + ------- + dict[str, dict[str, str]] + Text converted to integer IDs for each utterance, per dictionary name + """ + data = {} + utts = self.job_utts() + for dict_name, utt_data in utts.items(): + data[dict_name] = {} + for utt in utt_data.values(): + if utt.speaker.dictionary is None: + continue + if not utt.text: + continue + data[dict_name][utt.name] = " ".join(map(str, utt.text_int_for_scp())) + return data + + def wav_scp_data(self) -> dict[str, dict[str, str]]: + """ + Generate the job's data for Kaldi's wav scp files + + Returns + ------- + dict[str, dict[str, str]] + Wav scp strings for each file, per dictionary name + """ + data = {} + done = {} + utts = self.job_utts() + for dict_name, utt_data in utts.items(): + data[dict_name] = {} + done[dict_name] = set() + for utt in 
utt_data.values(): + if not utt.is_segment: + data[dict_name][utt.name] = utt.file.for_wav_scp() + elif utt.file.name not in done: + data[dict_name][utt.file.name] = utt.file.for_wav_scp() + done[dict_name].add(utt.file.name) + return data + + def utt2spk_scp_data(self) -> dict[str, dict[str, str]]: + """ + Generate the job's data for Kaldi's utt2spk scp files + + Returns + ------- + dict[str, dict[str, str]] + Utterance to speaker mapping, per dictionary name + """ + data = {} + utts = self.job_utts() + for dict_name, utt_data in utts.items(): + data[dict_name] = {} + for utt in utt_data.values(): + data[dict_name][utt.name] = utt.speaker_name + return data + + def feat_scp_data(self) -> dict[str, dict[str, str]]: + """ + Generate the job's data for Kaldi's feature scp files + + Returns + ------- + dict[str, dict[str, str]] + Utterance to feature archive ID mapping, per dictionary name + """ + data = {} + utts = self.job_utts() + for dict_name, utt_data in utts.items(): + data[dict_name] = {} + for utt in utt_data.values(): + if not utt.features: + continue + data[dict_name][utt.name] = utt.features + return data + + def spk2utt_scp_data(self) -> dict[str, dict[str, list[str]]]: + """ + Generate the job's data for Kaldi's spk2utt scp files + + Returns + ------- + dict[str, dict[str, list[str]]] + Speaker to utterance mapping, per dictionary name + """ + data = {} + utts = self.job_utts() + for dict_name, utt_data in utts.items(): + data[dict_name] = {} + for utt in utt_data.values(): + if utt.speaker.name not in data[dict_name]: + data[dict_name][utt.speaker.name] = [] + data[dict_name][utt.speaker.name].append(str(utt)) + for k, v in data.items(): + for s, utts in v.items(): + data[k][s] = sorted(utts) + return data + + def cmvn_scp_data(self) -> dict[str, dict[str, str]]: + """ + Generate the job's data for Kaldi's CMVN scp files + + Returns + ------- + dict[str, dict[str, str]] + Speaker to CMVN mapping, per dictionary name + """ + data = {} + for s in 
self.speakers: + if s.dictionary is None: + key = None + else: + key = s.dictionary.name + if key not in data: + data[key] = {} + if self.subset_speakers and s not in self.subset_speakers: + continue + if s.cmvn: + data[key][s.name] = s.cmvn + return data + + def segments_scp_data(self) -> dict[str, dict[str, str]]: + """ + Generate the job's data for Kaldi's segments scp files + + Returns + ------- + dict[str, dict[str, str]] + Utterance to segment mapping, per dictionary name + """ + data = {} + utts = self.job_utts() + for dict_name, utt_data in utts.items(): + data[dict_name] = {} + for utt in utt_data.values(): + if not utt.is_segment: + continue + data[dict_name][utt.name] = utt.segment_for_scp() + return data + + def construct_path_dictionary( + self, directory: str, identifier: str, extension: str + ) -> dict[str, str]: + """ + Helper function for constructing dictionary-dependent paths for the Job + + Parameters + ---------- + directory: str + Directory to use as the root + identifier: str + Identifier for the path name, like ali or acc + extension: str + Extension of the path, like .scp or .ark + + Returns + ------- + dict[str, str] + Path for each dictionary + """ + output = {} + for dict_name in self.current_dictionary_names: + output[dict_name] = os.path.join( + directory, f"{identifier}.{dict_name}.{self.name}.{extension}" + ) + return output + + def construct_dictionary_dependent_paths( + self, directory: str, identifier: str, extension: str + ) -> dict[str, str]: + """ + Helper function for constructing paths that depend only on the dictionaries of the job, and not the job name itself. + These paths should be merged with all other jobs to get a full set of dictionary paths. 
+ + Parameters + ---------- + directory: str + Directory to use as the root + identifier: str + Identifier for the path name, like ali or acc + extension: str + Extension of the path, like .scp or .ark + + Returns + ------- + dict[str, str] + Path for each dictionary + """ + output = {} + for dict_name in self.current_dictionary_names: + output[dict_name] = os.path.join(directory, f"{identifier}.{dict_name}.{extension}") + return output + + @property + def dictionary_count(self): + """Number of dictionaries currently used""" + if self.subset_dictionaries: + return len(self.subset_dictionaries) + return len(self.dictionaries) + + @property + def current_dictionaries(self) -> Collection[PronunciationDictionaryMixin]: + """Current dictionaries depending on whether a subset is being used""" + if self.subset_dictionaries: + return self.subset_dictionaries + return self.dictionaries + + @property + def current_dictionary_names(self) -> list[Optional[str]]: + """Current dictionary names depending on whether a subset is being used""" + if self.subset_dictionaries: + return sorted(x.name for x in self.subset_dictionaries) + if self.dictionaries == {None}: + return [None] + return sorted(x.name for x in self.dictionaries) + + def word_boundary_int_files(self) -> dict[str, str]: + """ + Generate mapping for dictionaries to word boundary int files + + Returns + ------- + dict[str, str] + Per dictionary word boundary int files + """ + data = {} + for dictionary in self.current_dictionaries: + data[dictionary.name] = os.path.join(dictionary.phones_dir, "word_boundary.int") + return data + + def reversed_phone_mappings(self) -> dict[str, ReversedMappingType]: + """ + Generate mapping for dictionaries to reversed phone mapping + + Returns + ------- + dict[str, ReversedMappingType] + Per dictionary reversed phone mapping + """ + data = {} + for dictionary in self.current_dictionaries: + data[dictionary.name] = dictionary.reversed_phone_mapping + return data + + def 
reversed_word_mappings(self) -> dict[str, ReversedMappingType]: + """ + Generate mapping for dictionaries to reversed word mapping + + Returns + ------- + dict[str, ReversedMappingType] + Per dictionary reversed word mapping + """ + data = {} + for dictionary in self.current_dictionaries: + data[dictionary.name] = dictionary.reversed_word_mapping + return data + + def words_mappings(self) -> dict[str, MappingType]: + """ + Generate mapping for dictionaries to word mapping + + Returns + ------- + dict[str, MappingType] + Per dictionary word mapping + """ + data = {} + for dictionary in self.current_dictionaries: + data[dictionary.name] = dictionary.words_mapping + return data + + def words(self) -> dict[str, WordsType]: + """ + Generate mapping for dictionaries to words + + Returns + ------- + dict[str, WordsType] + Per dictionary words + """ + data = {} + for dictionary in self.current_dictionaries: + data[dictionary.name] = dictionary.words + return data + + def punctuation(self): + """ + Generate mapping for dictionaries to punctuation + + Returns + ------- + dict[str, str] + Per dictionary punctuation + """ + data = {} + for dictionary in self.current_dictionaries: + data[dictionary.name] = dictionary.punctuation + return data + + def clitic_set(self) -> dict[str, set[str]]: + """ + Generate mapping for dictionaries to clitic sets + + Returns + ------- + dict[str, str] + Per dictionary clitic sets + """ + data = {} + for dictionary in self.current_dictionaries: + data[dictionary.name] = dictionary.clitic_set + return data + + def clitic_markers(self) -> dict[str, list[str]]: + """ + Generate mapping for dictionaries to clitic markers + + Returns + ------- + dict[str, str] + Per dictionary clitic markers + """ + data = {} + for dictionary in self.current_dictionaries: + data[dictionary.name] = dictionary.clitic_markers + return data + + def compound_markers(self) -> dict[str, list[str]]: + """ + Generate mapping for dictionaries to compound markers + + Returns + 
------- + dict[str, str] + Per dictionary compound markers + """ + data = {} + for dictionary in self.current_dictionaries: + data[dictionary.name] = dictionary.compound_markers + return data + + def strip_diacritics(self) -> dict[str, list[str]]: + """ + Generate mapping for dictionaries to diacritics to strip + + Returns + ------- + dict[str, list[str]] + Per dictionary strip diacritics + """ + data = {} + for dictionary in self.current_dictionaries: + data[dictionary.name] = dictionary.strip_diacritics + return data + + def oov_codes(self) -> dict[str, str]: + """ + Generate mapping for dictionaries to oov symbols + + Returns + ------- + dict[str, str] + Per dictionary oov symbols + """ + data = {} + for dictionary in self.current_dictionaries: + data[dictionary.name] = dictionary.oov_word + return data + + def oov_ints(self) -> dict[str, int]: + """ + Generate mapping for dictionaries to oov ints + + Returns + ------- + dict[str, int] + Per dictionary oov ints + """ + data = {} + for dictionary in self.current_dictionaries: + data[dictionary.name] = dictionary.oov_int + return data + + def positions(self) -> dict[str, list[str]]: + """ + Generate mapping for dictionaries to positions + + Returns + ------- + dict[str, list[str]] + Per dictionary positions + """ + data = {} + for dictionary in self.current_dictionaries: + data[dictionary.name] = dictionary.positions + return data + + def silences(self) -> dict[str, set[str]]: + """ + Generate mapping for dictionaries to silence symbols + + Returns + ------- + dict[str, set[str]] + Per dictionary silence symbols + """ + data = {} + for dictionary in self.current_dictionaries: + data[dictionary.name] = dictionary.silences + return data + + def multilingual_ipa(self) -> dict[str, bool]: + """ + Generate mapping for dictionaries to multilingual IPA flags + + Returns + ------- + dict[str, bool] + Per dictionary multilingual IPA flags + """ + data = {} + for dictionary in self.current_dictionaries: + 
data[dictionary.name] = dictionary.multilingual_ipa + return data + + def job_utts(self) -> dict[str, dict[str, Utterance]]: + """ + Generate utterances by dictionary name for the Job + + Returns + ------- + dict[str, dict[str, :class:`~montreal_forced_aligner.corpus.classes.Utterance`]] + Mapping of dictionary name to Utterance mappings + """ + data = {} + if self.subset_utts: + utterances = self.subset_utts + else: + utterances = set() + for s in self.speakers: + utterances.update(s.utterances.values()) + for u in utterances: + if u.ignored: + continue + if u.speaker.dictionary is None: + dict_name = None + else: + dict_name = u.speaker.dictionary.name + if dict_name not in data: + data[dict_name] = {} + data[dict_name][u.name] = u + + return data + + def job_files(self) -> dict[str, File]: + """ + Generate files for the Job + + Returns + ------- + dict[str, :class:`~montreal_forced_aligner.corpus.classes.File`] + Mapping of file name to File objects + """ + data = {} + if self.subset_utts: + utterances = self.subset_utts + else: + utterances = set() + for s in self.speakers: + utterances.update(s.utterances.values()) + for u in utterances: + if u.ignored: + continue + data[u.file_name] = u.file + return data + + def job_speakers(self) -> dict[str, Speaker]: + """ + Generate files for the Job + + Returns + ------- + dict[str, :class:`~montreal_forced_aligner.corpus.classes.Speaker`] + Mapping of file name to File objects + """ + data = {} + if self.subset_speakers: + speakers = self.subset_speakers + else: + speakers = self.speakers + for s in speakers: + data[s.name] = s + return data + + def dictionary_data(self) -> dict[str, DictionaryData]: + """ + Generate dictionary data for the job + + Returns + ------- + dict[str, DictionaryData] + Mapping of dictionary name to dictionary data + """ + data = {} + for dictionary in self.current_dictionaries: + data[dictionary.name] = dictionary.data() + return data + + def output_to_directory(self, split_directory: str) -> 
None: + """ + Output job information to a directory + + Parameters + ---------- + split_directory: str + Directory to output to + """ + wav = self.wav_scp_data() + for dict_name, scp in wav.items(): + wav_scp_path = os.path.join(split_directory, f"wav.{dict_name}.{self.name}.scp") + output_mapping(scp, wav_scp_path, skip_safe=True) + + spk2utt = self.spk2utt_scp_data() + for dict_name, scp in spk2utt.items(): + spk2utt_scp_path = os.path.join( + split_directory, f"spk2utt.{dict_name}.{self.name}.scp" + ) + output_mapping(scp, spk2utt_scp_path) + + feats = self.feat_scp_data() + for dict_name, scp in feats.items(): + feats_scp_path = os.path.join(split_directory, f"feats.{dict_name}.{self.name}.scp") + output_mapping(scp, feats_scp_path) + + cmvn = self.cmvn_scp_data() + for dict_name, scp in cmvn.items(): + cmvn_scp_path = os.path.join(split_directory, f"cmvn.{dict_name}.{self.name}.scp") + output_mapping(scp, cmvn_scp_path) + + utt2spk = self.utt2spk_scp_data() + for dict_name, scp in utt2spk.items(): + utt2spk_scp_path = os.path.join( + split_directory, f"utt2spk.{dict_name}.{self.name}.scp" + ) + output_mapping(scp, utt2spk_scp_path) + + segments = self.segments_scp_data() + for dict_name, scp in segments.items(): + segments_scp_path = os.path.join( + split_directory, f"segments.{dict_name}.{self.name}.scp" + ) + output_mapping(scp, segments_scp_path) + + text_scp = self.text_scp_data() + for dict_name, scp in text_scp.items(): + if not scp: + continue + text_scp_path = os.path.join(split_directory, f"text.{dict_name}.{self.name}.scp") + output_mapping(scp, text_scp_path) + + text_int = self.text_int_scp_data() + for dict_name, scp in text_int.items(): + if dict_name is None: + continue + if not scp: + continue + text_int_scp_path = os.path.join( + split_directory, f"text.{dict_name}.{self.name}.int.scp" + ) + output_mapping(scp, text_int_scp_path, skip_safe=True) diff --git a/montreal_forced_aligner/corpus/text_corpus.py 
    def _load_corpus_from_source_mp(self) -> None:
        """
        Load a corpus using multiprocessing

        Walks the corpus directory, queues one parse job per transcription
        file, and collects parsed File objects from worker processes.
        """
        if self.stopped is None:
            self.stopped = Stopped()
        begin_time = time.time()
        # Manager-backed queues/dicts so workers can share results and errors
        manager = mp.Manager()
        job_queue = manager.Queue()
        return_queue = manager.Queue()
        return_dict = manager.dict()
        return_dict["decode_error_files"] = manager.list()
        return_dict["textgrid_read_errors"] = manager.dict()
        finished_adding = Stopped()
        procs = []
        for _ in range(self.num_jobs):
            p = CorpusProcessWorker(
                job_queue, return_dict, return_queue, self.stopped, finished_adding
            )
            procs.append(p)
            p.start()
        try:
            for root, _, files in os.walk(self.corpus_directory, followlinks=True):
                identifiers, wav_files, lab_files, textgrid_files, other_audio_files = find_exts(
                    files
                )
                # Path of this directory relative to the corpus root, used for
                # reconstructing output structure later
                relative_path = root.replace(self.corpus_directory, "").lstrip("/").lstrip("\\")

                if self.stopped.stop_check():
                    break
                for file_name in identifiers:
                    if self.stopped.stop_check():
                        break
                    # Text-only corpus: no wav path is ever supplied
                    wav_path = None
                    transcription_path = None
                    if file_name in lab_files:
                        lab_name = lab_files[file_name]
                        transcription_path = os.path.join(root, lab_name)

                    elif file_name in textgrid_files:
                        tg_name = textgrid_files[file_name]
                        transcription_path = os.path.join(root, tg_name)
                    job_queue.put(
                        (
                            file_name,
                            wav_path,
                            transcription_path,
                            relative_path,
                            self.speaker_characters,
                            self.construct_sanitize_function(),
                        )
                    )

            # Signal workers that no more jobs will arrive, then wait for the
            # queue to drain
            finished_adding.stop()
            self.log_debug("Finished adding jobs!")
            job_queue.join()

            self.log_debug("Waiting for workers to finish...")
            # NOTE(review): workers are joined before the return queue is
            # drained below; presumably the queue buffering is large enough,
            # but this ordering could deadlock on very large corpora — confirm.
            for p in procs:
                p.join()

            while True:
                try:
                    file = return_queue.get(timeout=1)
                    if self.stopped.stop_check():
                        continue
                except Empty:
                    break

                self.add_file(file)

            if "error" in return_dict:
                # Re-raise the first unexpected worker exception in the parent
                raise return_dict["error"][1]

            for k in ["decode_error_files", "textgrid_read_errors"]:
                if hasattr(self, k):
                    if return_dict[k]:
                        self.log_info(
                            "There were some issues with files in the corpus. "
                            "Please look at the log file or run the validator for more information."
                        )
                        self.log_debug(f"{k} showed {len(return_dict[k])} errors:")
                        if k == "textgrid_read_errors":
                            # TextGrid errors are keyed by file name; merge them
                            getattr(self, k).update(return_dict[k])
                            for f, e in return_dict[k].items():
                                self.log_debug(f"{f}: {e.error}")
                        else:
                            self.log_debug(", ".join(return_dict[k]))
                            setattr(self, k, return_dict[k])

        except KeyboardInterrupt:
            self.log_info("Detected ctrl-c, please wait a moment while we clean everything up...")
            self.stopped.stop()
            finished_adding.stop()
            job_queue.join()
            self.stopped.set_sigint_source()
            # Drain any remaining results so worker puts cannot block shutdown
            while True:
                try:
                    _ = return_queue.get(timeout=1)
                    if self.stopped.stop_check():
                        continue
                except Empty:
                    break
        finally:

            if self.stopped.stop_check():
                self.log_info(f"Stopped parsing early ({time.time() - begin_time} seconds)")
                if self.stopped.source():
                    # Stop originated from a user interrupt: exit the process
                    sys.exit(0)
            else:
                self.log_debug(
                    f"Parsed corpus directory with {self.num_jobs} jobs in {time.time() - begin_time} seconds"
                )
class DictionaryTextCorpusMixin(TextCorpusMixin, MultispeakerDictionaryMixin):
    """
    Abstract mixin class for processing text corpora with pronunciation dictionaries.

    This is primarily useful for training language models, as you can treat words in the language model as OOV if they
    aren't in your pronunciation dictionary

    See Also
    --------
    :class:`~montreal_forced_aligner.corpus.text_corpus.TextCorpusMixin`
        For corpus parsing parameters
    :class:`~montreal_forced_aligner.dictionary.multispeaker.MultispeakerDictionaryMixin`
        For dictionary parsing parameters
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def load_corpus(self) -> None:
        """
        Load the corpus

        Sets up dictionaries, parses the corpus, writes lexicon information
        restricted to the corpus vocabulary, and prepares per-job splits.
        """
        self.dictionary_setup()
        self._load_corpus()
        # Restrict the lexicon to words actually seen in the corpus
        self.set_lexicon_word_set(self.corpus_word_set)
        self.write_lexicon_information()

        # Each speaker may use a different pronunciation dictionary
        for speaker in self.speakers.values():
            speaker.set_dictionary(self.get_dictionary(speaker.name))
        self.initialize_jobs()
        self.write_corpus_information()
        self.create_corpus_split()
:class:`~montreal_forced_aligner.corpus.text_corpus.DictionaryTextCorpusMixin` class rather than this class. + + Parameters + ---------- + num_jobs: int + Number of jobs to use when loading the corpus + + See Also + -------- + :class:`~montreal_forced_aligner.corpus.text_corpus.DictionaryTextCorpusMixin` + For dictionary and corpus parsing parameters + :class:`~montreal_forced_aligner.abc.MfaWorker` + For MFA processing parameters + :class:`~montreal_forced_aligner.abc.TemporaryDirectoryMixin` + For temporary directory parameters + """ + + def __init__(self, num_jobs=3, **kwargs): + super().__init__(**kwargs) + self.num_jobs = num_jobs + + def load_corpus(self) -> None: + """Load the corpus""" + self._load_corpus() + + @property + def identifier(self) -> str: + """Identifier for the corpus""" + return self.data_source_identifier + + @property + def output_directory(self) -> str: + """Root temporary directory to store all corpus and dictionary files""" + return os.path.join(self.temporary_directory, self.identifier) + + @property + def working_directory(self) -> str: + """Working directory""" + return self.output_directory + + def log_debug(self, message: str) -> None: + """ + Print a debug message + + Parameters + ---------- + message: str + Debug message to log + """ + print(message) + + def log_error(self, message: str) -> None: + """ + Print an error message + + Parameters + ---------- + message: str + Error message to log + """ + print(message) + + def log_info(self, message: str) -> None: + """ + Print an info message + + Parameters + ---------- + message: str + Info message to log + """ + print(message) + + def log_warning(self, message: str) -> None: + """ + Print a warning message + + Parameters + ---------- + message: str + Warning message to log + """ + print(message) diff --git a/montreal_forced_aligner/data.py b/montreal_forced_aligner/data.py index 80f6380c..2903e936 100644 --- a/montreal_forced_aligner/data.py +++ b/montreal_forced_aligner/data.py @@ 
-4,10 +4,11 @@ """ from dataclasses import dataclass -from typing import List from praatio.utilities.constants import Interval +from .exceptions import CtmError + __all__ = ["CtmInterval"] @@ -15,23 +16,27 @@ class CtmInterval: """ Data class for intervals derived from CTM files - - Attributes - ---------- - begin: float - Start time of interval - end: float - End time of interval - label: str - Text of interval - utterance: str - Utterance ID that the interval belongs to """ begin: float + """Start time of interval""" end: float + """End time of interval""" label: str + """Text of interval""" utterance: str + """Utterance ID that the interval belongs to""" + + def __post_init__(self): + """ + Check on data validity + + Raises + ------ + CtmError + If begin or end are not valid""" + if self.end < -1 or self.begin == 1000000: + raise CtmError(self) def shift_times(self, offset: float): """ @@ -41,7 +46,6 @@ def shift_times(self, offset: float): ---------- offset: float Offset to add to the interval's begin and end - """ self.begin += offset self.end += offset @@ -52,10 +56,9 @@ def to_tg_interval(self) -> Interval: Returns ------- - :class:`~praatio.utilities.constants.Interval` + :class:`praatio.utilities.constants.Interval` Derived PraatIO Interval """ + if self.end < -1 or self.begin == 1000000: + raise CtmError(self) return Interval(self.begin, self.end, self.label) - - -CtmType = List[CtmInterval] diff --git a/montreal_forced_aligner/dictionary/__init__.py b/montreal_forced_aligner/dictionary/__init__.py index 5796eeff..d721636e 100644 --- a/montreal_forced_aligner/dictionary/__init__.py +++ b/montreal_forced_aligner/dictionary/__init__.py @@ -4,18 +4,26 @@ """ -from .base_dictionary import PronunciationDictionary -from .data import DictionaryData -from .multispeaker import MultispeakerDictionary +from montreal_forced_aligner.dictionary.mixins import DictionaryMixin, SanitizeFunction +from montreal_forced_aligner.dictionary.multispeaker import ( + 
MultispeakerDictionary, + MultispeakerDictionaryMixin, +) +from montreal_forced_aligner.dictionary.pronunciation import ( + DictionaryData, + PronunciationDictionary, + PronunciationDictionaryMixin, +) __all__ = [ - "base_dictionary", + "pronunciation", "multispeaker", - "data", + "mixins", + "DictionaryData", + "DictionaryMixin", + "SanitizeFunction", "MultispeakerDictionary", + "MultispeakerDictionaryMixin", "PronunciationDictionary", - "DictionaryData", + "PronunciationDictionaryMixin", ] -MultispeakerDictionary.__module__ = "montreal_forced_aligner.dictionary" -PronunciationDictionary.__module__ = "montreal_forced_aligner.dictionary" -DictionaryData.__module__ = "montreal_forced_aligner.dictionary" diff --git a/montreal_forced_aligner/dictionary/data.py b/montreal_forced_aligner/dictionary/data.py deleted file mode 100644 index ebb1d503..00000000 --- a/montreal_forced_aligner/dictionary/data.py +++ /dev/null @@ -1,261 +0,0 @@ -"""Pronunciation dictionaries for use in alignment and transcription""" - -from __future__ import annotations - -import re -from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional - -from ..data import CtmInterval - -if TYPE_CHECKING: - IpaType = Optional[List[str]] - PunctuationType = Optional[str] - from ..abc import DictionaryEntryType, MappingType, ReversedMappingType, WordsType - from ..config.dictionary_config import DictionaryConfig - from ..data import CtmType - -__all__ = [ - "DictionaryData", -] - - -@dataclass -class DictionaryData: - """ - Information required for parsing Kaldi-internal ids to text - """ - - dictionary_config: DictionaryConfig - words_mapping: MappingType - reversed_words_mapping: ReversedMappingType - reversed_phone_mapping: ReversedMappingType - words: WordsType - - @property - def oov_int(self): - return self.words_mapping[self.dictionary_config.oov_word] - - def split_clitics( - self, - item: str, - ) -> List[str]: - """ - Split a word into subwords based on dictionary 
information - - Parameters - ---------- - item: str - Word to split - - Returns - ------- - List[str] - List of subwords - """ - if item in self.words: - return [item] - if any(x in item for x in self.dictionary_config.compound_markers): - s = re.split(rf"[{''.join(self.dictionary_config.compound_markers)}]", item) - if any(x in item for x in self.dictionary_config.clitic_markers): - new_s = [] - for seg in s: - if any(x in seg for x in self.dictionary_config.clitic_markers): - new_s.extend(self.split_clitics(seg)) - else: - new_s.append(seg) - s = new_s - return s - if any( - x in item and not item.endswith(x) and not item.startswith(x) - for x in self.dictionary_config.clitic_markers - ): - initial, final = re.split( - rf"[{''.join(self.dictionary_config.clitic_markers)}]", item, maxsplit=1 - ) - if any(x in final for x in self.dictionary_config.clitic_markers): - final = self.split_clitics(final) - else: - final = [final] - for clitic in self.dictionary_config.clitic_markers: - if initial + clitic in self.dictionary_config.clitic_set: - return [initial + clitic] + final - elif clitic + final[0] in self.dictionary_config.clitic_set: - final[0] = clitic + final[0] - return [initial] + final - return [item] - - def lookup( - self, - item: str, - ) -> List[str]: - """ - Look up a word and return the list of sub words if necessary - taking into account clitic and compound markers - - Parameters - ---------- - item: str - Word to look up - - Returns - ------- - List[str] - List of subwords that are in the dictionary - """ - - if item in self.words: - return [item] - sanitized = self.dictionary_config.sanitize(item) - if sanitized in self.words: - return [sanitized] - split = self.split_clitics(sanitized) - oov_count = sum(1 for x in split if x not in self.words) - - if oov_count < len( - split - ): # Only returned split item if it gains us any transcribed speech - return split - return [sanitized] - - def to_int( - self, - item: str, - ) -> List[int]: - """ - Convert 
a given word into integer IDs - - Parameters - ---------- - item: str - Word to look up - - Returns - ------- - List[int] - List of integer IDs corresponding to each subword - """ - if item == "": - return [] - sanitized = self.lookup(item) - text_int = [] - for item in sanitized: - if not item: - continue - if item not in self.words_mapping: - text_int.append(self.oov_int) - else: - text_int.append(self.words_mapping[item]) - return text_int - - def check_word(self, item: str) -> bool: - """ - Check whether a word is in the dictionary, takes into account sanitization and - clitic and compound markers - - Parameters - ---------- - item: str - Word to check - - Returns - ------- - bool - True if the look up would not result in an OOV item - """ - if item == "": - return False - if item in self.words: - return True - sanitized = self.dictionary_config.sanitize(item) - if sanitized in self.words: - return True - - sanitized = self.split_clitics(sanitized) - if all(s in self.words for s in sanitized): - return True - return False - - def map_to_original_pronunciation( - self, phones: CtmType, subpronunciations: List[DictionaryEntryType] - ) -> CtmType: - """ - Convert phone transcriptions from multilingual IPA mode to their original IPA transcription - - Parameters - ---------- - phones: List[CtmInterval] - List of aligned phones - subpronunciations: List[DictionaryEntryType] - Pronunciations of each sub word to reconstruct the transcriptions - - Returns - ------- - List[CtmInterval] - Intervals with their original IPA pronunciation rather than the internal simplified form - """ - transcription = tuple(x.label for x in phones) - new_phones = [] - mapping_ind = 0 - transcription_ind = 0 - for pronunciations in subpronunciations: - pron = None - if mapping_ind >= len(phones): - break - for p in pronunciations: - if ( - "original_pronunciation" in p - and transcription == p["pronunciation"] == p["original_pronunciation"] - ) or (transcription == p["pronunciation"] and 
"original_pronunciation" not in p): - new_phones.extend(phones) - mapping_ind += len(phones) - break - if ( - p["pronunciation"] - == transcription[ - transcription_ind : transcription_ind + len(p["pronunciation"]) - ] - and pron is None - ): - pron = p - if mapping_ind >= len(phones): - break - if not pron: - new_phones.extend(phones) - mapping_ind += len(phones) - break - to_extend = phones[transcription_ind : transcription_ind + len(pron["pronunciation"])] - transcription_ind += len(pron["pronunciation"]) - p = pron - if ( - "original_pronunciation" not in p - or p["pronunciation"] == p["original_pronunciation"] - ): - new_phones.extend(to_extend) - mapping_ind += len(to_extend) - break - for pi in p["original_pronunciation"]: - if pi == phones[mapping_ind].label: - new_phones.append(phones[mapping_ind]) - else: - modded_phone = pi - new_p = phones[mapping_ind].label - for diacritic in self.dictionary_config.strip_diacritics: - modded_phone = modded_phone.replace(diacritic, "") - if modded_phone == new_p: - phones[mapping_ind].label = pi - new_phones.append(phones[mapping_ind]) - elif mapping_ind != len(phones) - 1: - new_p = phones[mapping_ind].label + phones[mapping_ind + 1].label - if modded_phone == new_p: - new_phones.append( - CtmInterval( - phones[mapping_ind].begin, - phones[mapping_ind + 1].end, - new_p, - phones[mapping_ind].utterance, - ) - ) - mapping_ind += 1 - mapping_ind += 1 - return new_phones diff --git a/montreal_forced_aligner/dictionary/mixins.py b/montreal_forced_aligner/dictionary/mixins.py new file mode 100644 index 00000000..2a535f8f --- /dev/null +++ b/montreal_forced_aligner/dictionary/mixins.py @@ -0,0 +1,1039 @@ +"""Mixins for dictionary parsing capabilities""" + +from __future__ import annotations + +import abc +import os +import re +from collections import Counter +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional + +from montreal_forced_aligner.abc import TemporaryDirectoryMixin +from 
montreal_forced_aligner.data import CtmInterval + +if TYPE_CHECKING: + from montreal_forced_aligner.abc import ( + DictionaryEntryType, + MappingType, + MetaDict, + ReversedMappingType, + WordsType, + ) + +DEFAULT_PUNCTUATION = list(r'、。।,@<>"(),.:;¿?¡!\\&%#*~【】,…‥「」『』〝〟″⟨⟩♪・‹›«»~′$+=‘') + +DEFAULT_CLITIC_MARKERS = list("'’") +DEFAULT_COMPOUND_MARKERS = list("-/") +DEFAULT_STRIP_DIACRITICS = ["ː", "ˑ", "̩", "̆", "̑", "̯", "͡", "‿", "͜"] +DEFAULT_DIGRAPHS = ["[dt][szʒʃʐʑʂɕç]", "[aoɔe][ʊɪ]"] +DEFAULT_BRACKETS = [("[", "]"), ("{", "}"), ("<", ">"), ("(", ")")] + +__all__ = ["SanitizeFunction", "DictionaryMixin"] + + +class SanitizeFunction: + """ + Class for functions that sanitize text and strip punctuation + + Parameters + ---------- + punctuation: list[str] + List of characters to treat as punctuation + clitic_markers: list[str] + Characters that mark clitics + compound_markers: list[str] + Characters that mark compound words + brackets: list[tuple[str, str]] + List of bracket sets to not strip from the ends of words + """ + + def __init__( + self, + punctuation: list[str], + clitic_markers: list[str], + compound_markers: list[str], + brackets: list[tuple[str, str]], + ): + self.punctuation = punctuation + self.clitic_markers = clitic_markers + self.compound_markers = compound_markers + self.brackets = brackets + + def __call__(self, item): + """ + Sanitize an item according to punctuation and clitic markers + + Parameters + ---------- + item: str + Word to sanitize + + Returns + ------- + str + Sanitized form + """ + for c in self.clitic_markers: + item = item.replace(c, self.clitic_markers[0]) + if not item: + return item + for b in self.brackets: + if re.match(rf"^{re.escape(b[0])}.*{re.escape(b[1])}$", item): + return item + if self.punctuation: + item = re.sub(rf"^[{re.escape(''.join(self.punctuation))}]+", "", item) + item = re.sub(rf"[{re.escape(''.join(self.punctuation))}]+$", "", item) + return item + + +class SplitWordsFunction: + """ + Class for 
functions that splits words that have compound and clitic markers + + Parameters + ---------- + clitic_markers: list[str] + Characters that mark clitics + compound_markers: list[str] + Characters that mark compound words + """ + + def __init__( + self, + punctuation: list[str], + clitic_markers: list[str], + compound_markers: list[str], + brackets: list[tuple[str, str]], + clitic_set: set[str], + word_set: Optional[set[str]] = None, + ): + self.punctuation = punctuation + self.clitic_markers = clitic_markers + self.compound_markers = compound_markers + self.brackets = brackets + self.sanitize_function = SanitizeFunction( + punctuation, clitic_markers, compound_markers, brackets + ) + self.clitic_set = clitic_set + if not word_set: + word_set = None + self.word_set = word_set + self.compound_pattern = re.compile(rf"[{re.escape(''.join(self.compound_markers))}]") + initial_clitics = sorted( + x for x in self.clitic_set if any(x.endswith(y) for y in self.clitic_markers) + ) + final_clitics = sorted( + x for x in self.clitic_set if any(x.startswith(y) for y in self.clitic_markers) + ) + optional_initial_groups = f"({'|'.join(initial_clitics)})?" * 4 + optional_final_groups = f"({'|'.join(final_clitics)})?" 
* 4 + if initial_clitics and final_clitics: + self.clitic_pattern = re.compile( + rf"^(?:(?:{optional_initial_groups}(.+?))|(?:(.+?){optional_final_groups}))$" + ) + elif initial_clitics: + self.clitic_pattern = re.compile(rf"^(?:(?:{optional_initial_groups}(.+?))|(.+))$") + elif final_clitics: + self.clitic_pattern = re.compile(rf"^(?:(.+)|(?:(.+?){optional_final_groups}))$") + else: + self.clitic_pattern = None + + def split_clitics( + self, + item: str, + ) -> list[str]: + """ + Split a word into subwords based on dictionary information + + Parameters + ---------- + item: str + Word to split + + Returns + ------- + list[str] + List of subwords + """ + if self.word_set is not None and item in self.word_set: + return [item] + split = [] + s = re.split(self.compound_pattern, item) + for seg in s: + if self.clitic_pattern is None: + split.append(seg) + continue + + m = re.match(self.clitic_pattern, seg) + if not m: + split.append(seg) + continue + for g in m.groups(): + if g is None: + continue + split.append(g) + return split + + def __call__( + self, + item: str, + ) -> list[str]: + """ + Return the list of sub words if necessary + taking into account clitic and compound markers + + Parameters + ---------- + item: str + Word to look up + + Returns + ------- + list[str] + List of subwords that are in the dictionary + """ + if self.word_set is not None and item in self.word_set: + return [item] + sanitized = self.sanitize_function(item) + if self.word_set is not None and sanitized in self.word_set: + return [sanitized] + split = self.split_clitics(sanitized) + if self.word_set is None: + return split + oov_count = sum(1 for x in split if x not in self.word_set) + if oov_count < len( + split + ): # Only returned split item if it gains us any transcribed speech + return split + return [sanitized] + + +class DictionaryMixin: + """ + Abstract class for MFA classes that use acoustic models + + Parameters + ---------- + oov_code : str + What to label words not in the 
dictionary, defaults to ``''`` + position_dependent_phones : bool + Specifies whether phones should be represented as dependent on their + position in the word (beginning, middle or end), defaults to True + num_silence_states : int + Number of states to use for silence phones, defaults to 5 + num_non_silence_states : int + Number of states to use for non-silence phones, defaults to 3 + shared_silence_phones : bool + Specify whether to share states across all silence phones, defaults + to True + silence_probability : float + Probability of optional silences following words, defaults to 0.5 + punctuation: str, optional + Punctuation to use when parsing text + clitic_markers: str, optional + Clitic markers to use when parsing text + compound_markers: str, optional + Compound markers to use when parsing text + multilingual_ipa: bool + Flag for multilingual IPA mode, defaults to False + strip_diacritics: list[str], optional + Diacritics to strip in multilingual IPA mode + digraphs: list[str], optional + Digraphs to split up in multilingual IPA mode + brackets: list[tuple[str, str], optional + Character tuples to treat as full brackets around words + clitic_set: set[str] + Set of clitic words + disambiguation_symbols: set[str] + Set of disambiguation symbols + max_disambiguation_symbol: int + Maximum number of disambiguation symbols required, defaults to 0 + """ + + positions: list[str] = ["_B", "_E", "_I", "_S"] + + def __init__( + self, + oov_word: str = "", + silence_word: str = "!sil", + nonoptional_silence_phone: str = "sil", + optional_silence_phone: str = "sp", + oov_phone: str = "spn", + other_noise_phone: str = "spn", + position_dependent_phones: bool = True, + num_silence_states: int = 5, + num_non_silence_states: int = 3, + shared_silence_phones: bool = True, + silence_probability: float = 0.5, + punctuation: list[str] = None, + clitic_markers: list[str] = None, + compound_markers: list[str] = None, + multilingual_ipa: bool = False, + strip_diacritics: 
list[str] = None, + digraphs: list[str] = None, + brackets: list[tuple[str, str]] = None, + non_silence_phones: set[str] = None, + disambiguation_symbols: set[str] = None, + clitic_set: set[str] = None, + max_disambiguation_symbol: int = 0, + **kwargs, + ): + super().__init__(**kwargs) + self.strip_diacritics = DEFAULT_STRIP_DIACRITICS + self.digraphs = DEFAULT_DIGRAPHS + self.punctuation = DEFAULT_PUNCTUATION + self.clitic_markers = DEFAULT_CLITIC_MARKERS + self.compound_markers = DEFAULT_COMPOUND_MARKERS + self.brackets = DEFAULT_BRACKETS + if strip_diacritics is not None: + self.strip_diacritics = strip_diacritics + if digraphs is not None: + self.digraphs = digraphs + if punctuation is not None: + self.punctuation = punctuation + if clitic_markers is not None: + self.clitic_markers = clitic_markers + if compound_markers is not None: + self.compound_markers = compound_markers + if brackets is not None: + self.brackets = brackets + + self.multilingual_ipa = multilingual_ipa + self.num_silence_states = num_silence_states + self.num_non_silence_states = num_non_silence_states + self.shared_silence_phones = shared_silence_phones + self.silence_probability = silence_probability + self.oov_word = oov_word + self.silence_word = silence_word + self.position_dependent_phones = position_dependent_phones + self.optional_silence_phone = optional_silence_phone + self.nonoptional_silence_phone = nonoptional_silence_phone + self.oov_phone = oov_phone + self.oovs_found = Counter() + self.other_noise_phone = other_noise_phone + if non_silence_phones is None: + non_silence_phones = set() + self.non_silence_phones = non_silence_phones + self.max_disambiguation_symbol = max_disambiguation_symbol + if disambiguation_symbols is None: + disambiguation_symbols = set() + self.disambiguation_symbols = disambiguation_symbols + if clitic_set is None: + clitic_set = set() + self.clitic_set = clitic_set + + @property + def dictionary_options(self) -> MetaDict: + """Dictionary options""" + 
return { + "strip_diacritics": self.strip_diacritics, + "digraphs": self.digraphs, + "punctuation": self.punctuation, + "clitic_markers": self.clitic_markers, + "clitic_set": self.clitic_set, + "compound_markers": self.compound_markers, + "brackets": self.brackets, + "multilingual_ipa": self.multilingual_ipa, + "num_silence_states": self.num_silence_states, + "num_non_silence_states": self.num_non_silence_states, + "shared_silence_phones": self.shared_silence_phones, + "silence_probability": self.silence_probability, + "oov_word": self.oov_word, + "silence_word": self.silence_word, + "position_dependent_phones": self.position_dependent_phones, + "optional_silence_phone": self.optional_silence_phone, + "nonoptional_silence_phone": self.nonoptional_silence_phone, + "oov_phone": self.oov_phone, + "other_noise_phone": self.other_noise_phone, + "non_silence_phones": self.non_silence_phones, + "max_disambiguation_symbol": self.max_disambiguation_symbol, + "disambiguation_symbols": self.disambiguation_symbols, + } + + @property + def silence_phones(self): + """Silence phones""" + return { + self.oov_phone, + self.optional_silence_phone, + self.nonoptional_silence_phone, + self.other_noise_phone, + } + + @property + def specials_set(self): + """Special words, like the ``oov_word`` ``silence_word``, ````, ````, and ````""" + return {self.oov_word, self.silence_word, "", "", ""} + + @property + def phone_mapping(self) -> dict[str, int]: + """Mapping of phones to integer IDs""" + phone_mapping = {} + i = 0 + phone_mapping[""] = i + if self.position_dependent_phones: + for p in self.positional_silence_phones: + i += 1 + phone_mapping[p] = i + for p in self.positional_non_silence_phones: + i += 1 + phone_mapping[p] = i + else: + for p in sorted(self.silence_phones): + i += 1 + phone_mapping[p] = i + for p in sorted(self.non_silence_phones): + i += 1 + phone_mapping[p] = i + i = max(phone_mapping.values()) + for x in range(self.max_disambiguation_symbol + 2): + p = f"#{x}" + 
self.disambiguation_symbols.add(p) + i += 1 + phone_mapping[p] = i + return phone_mapping + + @property + def positional_silence_phones(self) -> list[str]: + """ + List of silence phones with positions + """ + silence_phones = [] + for p in sorted(self.silence_phones): + silence_phones.append(p) + for pos in self.positions: + silence_phones.append(p + pos) + return silence_phones + + @property + def positional_non_silence_phones(self) -> list[str]: + """ + List of non-silence phones with positions + """ + non_silence_phones = [] + for p in sorted(self.non_silence_phones): + for pos in self.positions: + non_silence_phones.append(p + pos) + return non_silence_phones + + @property + def kaldi_silence_phones(self): + """Silence phones in Kaldi format""" + if self.position_dependent_phones: + return self.positional_silence_phones + return sorted(self.silence_phones) + + def save_oovs_found(self, directory: str) -> None: + """ + Save all out of vocabulary items to a file in the specified directory + + Parameters + ---------- + directory : str + Path to directory to save ``oovs_found.txt`` + """ + with open(os.path.join(directory, "oovs_found.txt"), "w", encoding="utf8") as f, open( + os.path.join(directory, "oov_counts.txt"), "w", encoding="utf8" + ) as cf: + for oov in sorted(self.oovs_found.keys(), key=lambda x: (-self.oovs_found[x], x)): + f.write(oov + "\n") + cf.write(f"{oov}\t{self.oovs_found[oov]}\n") + + @property + def kaldi_non_silence_phones(self): + """Non silence phones in Kaldi format""" + if self.position_dependent_phones: + return self.positional_non_silence_phones + return sorted(self.non_silence_phones) + + @property + def optional_silence_csl(self) -> str: + """ + Phone ID of the optional silence phone + """ + return str(self.phone_mapping[self.optional_silence_phone]) + + @property + def silence_csl(self) -> str: + """ + A colon-separated string of silence phone ids + """ + return ":".join(map(str, (self.phone_mapping[x] for x in 
self.kaldi_silence_phones))) + + @property + def phones(self) -> set: + """ + The set of all phones (silence and non-silence) + """ + return self.silence_phones | self.non_silence_phones + + def check_bracketed(self, word: str) -> bool: + """ + Checks whether a given string is surrounded by brackets. + + Parameters + ---------- + word : str + Text to check for final brackets + + Returns + ------- + bool + True if the word is fully bracketed, false otherwise + """ + for b in self.brackets: + if re.match(rf"^{re.escape(b[0])}.*{re.escape(b[1])}$", word): + return True + return False + + def construct_sanitize_function(self) -> SanitizeFunction: + """ + Construct a :class:`~montreal_forced_aligner.dictionary.mixins.SanitizeFunction` to use in multiprocessing jobs + + Returns + ------- + :class:`~montreal_forced_aligner.dictionary.mixins.SanitizeFunction` + Function for sanitizing text + """ + f = SanitizeFunction( + self.punctuation, self.clitic_markers, self.compound_markers, self.brackets + ) + + return f + + def sanitize(self, item: str) -> str: + """ + Sanitize an item according to punctuation and clitic markers + + Parameters + ---------- + item: str + Word to sanitize + + Returns + ------- + str + Sanitized form + """ + return self.construct_sanitize_function()(item) + + def parse_ipa(self, transcription: list[str]) -> tuple[str, ...]: + """ + Parse a transcription in a multilingual IPA format (strips out diacritics and splits digraphs). + + Parameters + ---------- + transcription: list[str] + Transcription to parse + + Returns + ------- + tuple[str, ...] 
+ Parsed transcription + """ + new_transcription = [] + for t in transcription: + new_t = t + for d in self.strip_diacritics: + new_t = new_t.replace(d, "") + if "g" in new_t: + new_t = new_t.replace("g", "ɡ") + + found = False + for digraph in self.digraphs: + if re.match(rf"^{digraph}$", new_t): + found = True + if found: + new_transcription.extend(new_t) + continue + new_transcription.append(new_t) + return tuple(new_transcription) + + +class TemporaryDictionaryMixin(DictionaryMixin, TemporaryDirectoryMixin, metaclass=abc.ABCMeta): + def _write_word_boundaries(self) -> None: + """ + Write the word boundaries file to the temporary directory + """ + boundary_path = os.path.join( + self.dictionary_output_directory, "phones", "word_boundary.txt" + ) + boundary_int_path = os.path.join( + self.dictionary_output_directory, "phones", "word_boundary.int" + ) + with open(boundary_path, "w", encoding="utf8") as f, open( + boundary_int_path, "w", encoding="utf8" + ) as intf: + if self.position_dependent_phones: + for p in sorted(self.phone_mapping.keys(), key=lambda x: self.phone_mapping[x]): + if p == "" or p.startswith("#"): + continue + cat = "nonword" + if p.endswith("_B"): + cat = "begin" + elif p.endswith("_S"): + cat = "singleton" + elif p.endswith("_I"): + cat = "internal" + elif p.endswith("_E"): + cat = "end" + f.write(" ".join([p, cat]) + "\n") + intf.write(" ".join([str(self.phone_mapping[p]), cat]) + "\n") + + def _write_topo(self) -> None: + """ + Write the topo file to the temporary directory + """ + topo_template = " {cur_state} {cur_state} {cur_state} 0.75 {next_state} 0.25 " + topo_sil_template = " {cur_state} {cur_state} {transitions} " + topo_transition_template = " {} {}" + + sil_transp = 1 / (self.num_silence_states - 1) + initial_transition = [ + topo_transition_template.format(x, sil_transp) + for x in range(self.num_silence_states - 1) + ] + middle_transition = [ + topo_transition_template.format(x, sil_transp) + for x in range(1, 
self.num_silence_states) + ] + final_transition = [ + topo_transition_template.format(self.num_silence_states - 1, 0.75), + topo_transition_template.format(self.num_silence_states, 0.25), + ] + with open(self.topo_path, "w") as f: + f.write("\n") + f.write("\n") + f.write("\n") + phones = self.kaldi_non_silence_phones + f.write(f"{' '.join(str(self.phone_mapping[x]) for x in phones)}\n") + f.write("\n") + states = [ + topo_template.format(cur_state=x, next_state=x + 1) + for x in range(self.num_non_silence_states) + ] + f.write("\n".join(states)) + f.write(f"\n {self.num_non_silence_states} \n") + f.write("\n") + + f.write("\n") + f.write("\n") + + phones = self.kaldi_silence_phones + f.write(f"{' '.join(str(self.phone_mapping[x]) for x in phones)}\n") + f.write("\n") + states = [] + for i in range(self.num_silence_states): + if i == 0: + transition = " ".join(initial_transition) + elif i == self.num_silence_states - 1: + transition = " ".join(final_transition) + else: + transition = " ".join(middle_transition) + states.append(topo_sil_template.format(cur_state=i, transitions=transition)) + f.write("\n".join(states)) + f.write(f"\n {self.num_silence_states} \n") + f.write("\n") + f.write("\n") + + def _write_phone_sets(self) -> None: + """ + Write phone symbol sets to the temporary directory + """ + sharesplit = ["shared", "split"] + if not self.shared_silence_phones: + sil_sharesplit = ["not-shared", "not-split"] + else: + sil_sharesplit = sharesplit + + sets_file = os.path.join(self.dictionary_output_directory, "phones", "sets.txt") + roots_file = os.path.join(self.dictionary_output_directory, "phones", "roots.txt") + + sets_int_file = os.path.join(self.dictionary_output_directory, "phones", "sets.int") + roots_int_file = os.path.join(self.dictionary_output_directory, "phones", "roots.int") + + with open(sets_file, "w", encoding="utf8") as setf, open( + roots_file, "w", encoding="utf8" + ) as rootf, open(sets_int_file, "w", encoding="utf8") as setintf, open( + 
roots_int_file, "w", encoding="utf8" + ) as rootintf: + + # process silence phones + for i, sp in enumerate(self.silence_phones): + if self.position_dependent_phones: + mapped = [sp + x for x in [""] + self.positions] + else: + mapped = [sp] + setf.write(" ".join(mapped) + "\n") + setintf.write(" ".join(map(str, (self.phone_mapping[x] for x in mapped))) + "\n") + if i == 0: + line = sil_sharesplit + mapped + lineint = sil_sharesplit + [str(self.phone_mapping[x]) for x in mapped] + else: + line = sharesplit + mapped + lineint = sharesplit + [str(self.phone_mapping[x]) for x in mapped] + rootf.write(" ".join(line) + "\n") + rootintf.write(" ".join(lineint) + "\n") + + # process nonsilence phones + for nsp in sorted(self.non_silence_phones): + if self.position_dependent_phones: + mapped = [nsp + x for x in self.positions] + else: + mapped = [nsp] + setf.write(" ".join(mapped) + "\n") + setintf.write(" ".join(map(str, (self.phone_mapping[x] for x in mapped))) + "\n") + line = sharesplit + mapped + lineint = sharesplit + [str(self.phone_mapping[x]) for x in mapped] + rootf.write(" ".join(line) + "\n") + rootintf.write(" ".join(lineint) + "\n") + + @property + def phone_symbol_table_path(self): + """Path to file containing phone symbols and their integer IDs""" + return os.path.join(self.phones_dir, "phones.txt") + + def _write_phone_symbol_table(self) -> None: + """ + Write the phone mapping to the temporary directory + """ + with open(self.phone_symbol_table_path, "w", encoding="utf8") as f: + for p, i in sorted(self.phone_mapping.items(), key=lambda x: x[1]): + f.write(f"{p} {i}\n") + + @property + def disambiguation_symbols_txt_path(self): + """Path to the file containing phone disambiguation symbols""" + return os.path.join(self.phones_dir, "disambiguation_symbols.txt") + + @property + def disambiguation_symbols_int_path(self): + """Path to the file containing integer IDs for phone disambiguation symbols""" + return os.path.join(self.phones_dir, 
"disambiguation_symbols.int") + + @property + def phones_dir(self) -> str: + """Directory for storing phone information""" + return os.path.join(self.dictionary_output_directory, "phones") + + @property + def topo_path(self) -> str: + """Path to the dictionary's topology file""" + return os.path.join(self.phones_dir, "topo") + + def _write_extra_questions(self) -> None: + """ + Write extra questions symbols to the temporary directory + """ + phone_extra = os.path.join(self.phones_dir, "extra_questions.txt") + phone_extra_int = os.path.join(self.phones_dir, "extra_questions.int") + with open(phone_extra, "w", encoding="utf8") as outf, open( + phone_extra_int, "w", encoding="utf8" + ) as intf: + silences = self.kaldi_silence_phones + outf.write(" ".join(silences) + "\n") + intf.write(" ".join(str(self.phone_mapping[x]) for x in silences) + "\n") + + non_silences = self.kaldi_non_silence_phones + outf.write(" ".join(non_silences) + "\n") + intf.write(" ".join(str(self.phone_mapping[x]) for x in non_silences) + "\n") + if self.position_dependent_phones: + for p in self.positions: + line = [x + p for x in sorted(self.non_silence_phones)] + outf.write(" ".join(line) + "\n") + intf.write(" ".join(str(self.phone_mapping[x]) for x in line) + "\n") + for p in [""] + self.positions: + line = [x + p for x in sorted(self.silence_phones)] + outf.write(" ".join(line) + "\n") + intf.write(" ".join(str(self.phone_mapping[x]) for x in line) + "\n") + + def _write_disambig(self) -> None: + """ + Write disambiguation symbols to the temporary directory + """ + disambig = self.disambiguation_symbols_txt_path + disambig_int = self.disambiguation_symbols_int_path + with open(disambig, "w", encoding="utf8") as outf, open( + disambig_int, "w", encoding="utf8" + ) as intf: + for d in sorted(self.disambiguation_symbols, key=lambda x: self.phone_mapping[x]): + outf.write(f"{d}\n") + intf.write(f"{self.phone_mapping[d]}\n") + + def _write_phone_map_file(self) -> None: + """ + Write the phone 
map to the temporary directory + """ + outfile = os.path.join(self.phones_dir, "phone_map.txt") + with open(outfile, "w", encoding="utf8") as f: + for sp in self.silence_phones: + if self.position_dependent_phones: + new_phones = [sp + x for x in ["", ""] + self.positions] + else: + new_phones = [sp] + f.write(" ".join(new_phones) + "\n") + for nsp in self.non_silence_phones: + if self.position_dependent_phones: + new_phones = [nsp + x for x in [""] + self.positions] + else: + new_phones = [nsp] + f.write(" ".join(new_phones) + "\n") + + +@dataclass +class DictionaryData: + """ + Information required for parsing Kaldi-internal ids to text + + Attributes + ---------- + dictionary_options: dict[str, Any] + Options for the dictionary + sanitize_function: :class:`~montreal_forced_aligner.dictionary.mixins.SanitizeFunction` + Function to sanitize text + split_function: :class:`~montreal_forced_aligner.dictionary.mixins.SplitWordsFunction` + Function to split words into subwords + words_mapping: MappingType + Mapping from words to their integer IDs + reversed_words_mapping: ReversedMappingType + Mapping from integer IDs to words + words: WordsType + Words and their associated pronunciations + """ + + dictionary_options: MetaDict + sanitize_function: SanitizeFunction + split_function: SplitWordsFunction + words_mapping: MappingType + reversed_words_mapping: ReversedMappingType + words: WordsType + lookup_cache: dict[str, list[str]] + + @property + def oov_word(self) -> str: + """Out of vocabulary code""" + return self.dictionary_options["oov_word"] + + @property + def oov_int(self) -> int: + """Out of vocabulary integer ID""" + return self.words_mapping[self.oov_word] + + @property + def compound_markers(self) -> list[str]: + """Characters that separate compound words""" + return self.dictionary_options["compound_markers"] + + @property + def clitic_markers(self) -> list[str]: + """Characters that mark clitics""" + return self.dictionary_options["clitic_markers"] + + 
@property + def clitic_set(self) -> set[str]: + """Set of clitics""" + return self.dictionary_options["clitic_set"] + + @property + def punctuation(self) -> list[str]: + """Characters to treat as punctuation""" + return self.dictionary_options["punctuation"] + + @property + def strip_diacritics(self) -> list[str]: + """IPA diacritics to strip in multilingual IPA mode""" + return self.dictionary_options["strip_diacritics"] + + @property + def multilingual_ipa(self) -> bool: + """Flag for multilingual IPA mode""" + return self.dictionary_options["multilingual_ipa"] + + @property + def silence_phones(self) -> set[str]: + """Silence phones""" + return { + self.dictionary_options["oov_phone"], + self.dictionary_options["optional_silence_phone"], + self.dictionary_options["nonoptional_silence_phone"], + self.dictionary_options["other_noise_phone"], + } + + def lookup( + self, + item: str, + ) -> list[str]: + """ + Look up a word and return the list of sub words if necessary + taking into account clitic and compound markers + + Parameters + ---------- + item: str + Word to look up + + Returns + ------- + list[str] + List of subwords that are in the dictionary + """ + if item in self.lookup_cache: + return self.lookup_cache[item] + if item in self.words: + return [item] + sanitized = self.sanitize_function(item) + if sanitized in self.words: + self.lookup_cache[item] = [sanitized] + return [sanitized] + split = self.split_function(sanitized) + oov_count = sum(1 for x in split if x not in self.words) + if oov_count < len( + split + ): # Only returned split item if it gains us any transcribed speech + self.lookup_cache[item] = split + return split + self.lookup_cache[item] = [sanitized] + return [sanitized] + + def to_int( + self, + item: str, + ) -> list[int]: + """ + Convert a given word into integer IDs + + Parameters + ---------- + item: str + Word to look up + + Returns + ------- + list[int] + List of integer IDs corresponding to each subword + """ + if item == "": + 
return [] + sanitized = self.lookup(item) + text_int = [] + for item in sanitized: + if not item: + continue + if item not in self.words_mapping: + text_int.append(self.oov_int) + else: + text_int.append(self.words_mapping[item]) + return text_int + + def check_word(self, item: str) -> bool: + """ + Check whether a word is in the dictionary, takes into account sanitization and + clitic and compound markers + + Parameters + ---------- + item: str + Word to check + + Returns + ------- + bool + True if the look up would not result in an OOV item + """ + if item == "": + return False + if item in self.words: + return True + sanitized = self.sanitize_function(item) + if sanitized in self.words: + return True + + sanitized = self.split_function(sanitized) + if all(s in self.words for s in sanitized): + return True + return False + + def map_to_original_pronunciation( + self, phones: list[CtmInterval], subpronunciations: list[DictionaryEntryType] + ) -> list[CtmInterval]: + """ + Convert phone transcriptions from multilingual IPA mode to their original IPA transcription + + Parameters + ---------- + phones: list[CtmInterval] + List of aligned phones + subpronunciations: list[DictionaryEntryType] + Pronunciations of each sub word to reconstruct the transcriptions + + Returns + ------- + list[CtmInterval] + Intervals with their original IPA pronunciation rather than the internal simplified form + """ + transcription = tuple(x.label for x in phones) + new_phones = [] + mapping_ind = 0 + transcription_ind = 0 + for pronunciations in subpronunciations: + pron = None + if mapping_ind >= len(phones): + break + for p in pronunciations: + if ( + "original_pronunciation" in p + and transcription == p["pronunciation"] == p["original_pronunciation"] + ) or (transcription == p["pronunciation"] and "original_pronunciation" not in p): + new_phones.extend(phones) + mapping_ind += len(phones) + break + if ( + p["pronunciation"] + == transcription[ + transcription_ind : transcription_ind + 
len(p["pronunciation"]) + ] + and pron is None + ): + pron = p + if mapping_ind >= len(phones): + break + if not pron: + new_phones.extend(phones) + mapping_ind += len(phones) + break + to_extend = phones[transcription_ind : transcription_ind + len(pron["pronunciation"])] + transcription_ind += len(pron["pronunciation"]) + p = pron + if ( + "original_pronunciation" not in p + or p["pronunciation"] == p["original_pronunciation"] + ): + new_phones.extend(to_extend) + mapping_ind += len(to_extend) + break + for pi in p["original_pronunciation"]: + if pi == phones[mapping_ind].label: + new_phones.append(phones[mapping_ind]) + else: + modded_phone = pi + new_p = phones[mapping_ind].label + for diacritic in self.strip_diacritics: + modded_phone = modded_phone.replace(diacritic, "") + if modded_phone == new_p: + phones[mapping_ind].label = pi + new_phones.append(phones[mapping_ind]) + elif mapping_ind != len(phones) - 1: + new_p = phones[mapping_ind].label + phones[mapping_ind + 1].label + if modded_phone == new_p: + new_phones.append( + CtmInterval( + phones[mapping_ind].begin, + phones[mapping_ind + 1].end, + new_p, + phones[mapping_ind].utterance, + ) + ) + mapping_ind += 1 + mapping_ind += 1 + return new_phones diff --git a/montreal_forced_aligner/dictionary/multispeaker.py b/montreal_forced_aligner/dictionary/multispeaker.py index fcf2a327..b14d318b 100644 --- a/montreal_forced_aligner/dictionary/multispeaker.py +++ b/montreal_forced_aligner/dictionary/multispeaker.py @@ -2,124 +2,86 @@ from __future__ import annotations -import logging +import abc import os -from collections import Counter -from typing import TYPE_CHECKING, Collection, Dict, Optional, Union +from typing import TYPE_CHECKING, Collection, Optional, Union -from ..abc import Dictionary -from ..config.dictionary_config import DictionaryConfig -from ..models import DictionaryModel -from .base_dictionary import PronunciationDictionary +from montreal_forced_aligner.dictionary.mixins import 
TemporaryDictionaryMixin +from montreal_forced_aligner.dictionary.pronunciation import PronunciationDictionary +from montreal_forced_aligner.models import DictionaryModel if TYPE_CHECKING: + from montreal_forced_aligner.corpus.classes import Speaker - from ..corpus.classes import Speaker +__all__ = ["MultispeakerDictionaryMixin", "MultispeakerDictionary"] -__all__ = [ - "MultispeakerDictionary", -] - -class MultispeakerDictionary(Dictionary): +class MultispeakerDictionaryMixin(TemporaryDictionaryMixin, metaclass=abc.ABCMeta): """ - Class containing information about a pronunciation dictionary with different dictionaries per speaker + Mixin class containing information about a pronunciation dictionary with different dictionaries per speaker Parameters ---------- - dictionary_model : DictionaryModel - Multispeaker dictionary - output_directory : str - Path to a directory to store files for Kaldi - config: DictionaryConfig, optional - Configuration for generating lexicons - word_set : Collection[str], optional - Word set to limit output files - logger: :class:`~logging.Logger`, optional - Logger to output information to - """ + dictionary_path : str + Dictionary path + kwargs : kwargs + Extra parameters to passed to parent classes (see below) + + See Also + -------- + :class:`~montreal_forced_aligner.dictionary.mixins.DictionaryMixin` + For dictionary parsing parameters + :class:`~montreal_forced_aligner.abc.TemporaryDirectoryMixin` + For temporary directory parameters - def __init__( - self, - dictionary_model: Union[DictionaryModel, str], - output_directory: str, - config: Optional[DictionaryConfig] = None, - word_set: Optional[Collection[str]] = None, - logger: Optional[logging.Logger] = None, - ): - if isinstance(dictionary_model, str): - dictionary_model = DictionaryModel(dictionary_model) - if config is None: - config = DictionaryConfig() - super().__init__(dictionary_model, config) - self.output_directory = os.path.join(output_directory, "dictionary") - 
os.makedirs(self.output_directory, exist_ok=True) - self.log_file = os.path.join(self.output_directory, "dictionary.log") - if logger is None: - self.logger = logging.getLogger("dictionary_setup") - self.logger.setLevel(logging.INFO) - handler = logging.FileHandler(self.log_file, "w", "utf-8") - handler.setFormatter = logging.Formatter("%(name)s %(message)s") - self.logger.addHandler(handler) - else: - self.logger = logger + Attributes + ---------- + dictionary_model: :class:`~montreal_forced_aligner.models.DictionaryModel` + Dictionary model + speaker_mapping: dict[str, str] + Mapping of speaker names to dictionary names + dictionary_mapping: dict[str, :class:`~montreal_forced_aligner.dictionary.pronunciation.PronunciationDictionary`] + Mapping of dictionary names to pronunciation dictionary + """ + + def __init__(self, dictionary_path: str = None, **kwargs): + super().__init__(**kwargs) + self.dictionary_model = DictionaryModel(dictionary_path) self.speaker_mapping = {} - self.dictionary_mapping = {} + self.dictionary_mapping: dict[str, PronunciationDictionary] = {} + def dictionary_setup(self): + """Setup the dictionary for processing""" for speaker, dictionary in self.dictionary_model.load_dictionary_paths().items(): self.speaker_mapping[speaker] = dictionary.name if dictionary.name not in self.dictionary_mapping: self.dictionary_mapping[dictionary.name] = PronunciationDictionary( - dictionary, - self.output_directory, - config, - word_set=word_set, - logger=self.logger, + dictionary_path=dictionary.path, + temporary_directory=self.dictionary_output_directory, + root_dictionary=self, + **self.dictionary_options, ) + self.non_silence_phones.update( + self.dictionary_mapping[dictionary.name].non_silence_phones + ) + for dictionary in self.dictionary_mapping.values(): + dictionary.non_silence_phones = self.non_silence_phones @property - def phones_dir(self): - return self.get_dictionary("default").phones_dir - - @property - def topo_path(self): - return 
os.path.join(self.get_dictionary("default").output_directory, "topo") + def name(self) -> str: + """Name of the dictionary""" + return self.dictionary_model.name - @property - def oovs_found(self) -> Counter[str, int]: - oovs = Counter() + def calculate_oovs_found(self) -> None: + """Sum the counts of oovs found in pronunciation dictionaries""" for dictionary in self.dictionary_mapping.values(): - oovs.update(dictionary.oovs_found) - return oovs - - def save_oovs_found(self, directory: str) -> None: - """ - Save all out of vocabulary items to a file in the specified directory - - Parameters - ---------- - directory : str - Path to directory to save ``oovs_found.txt`` - """ - with open(os.path.join(directory, "oovs_found.txt"), "w", encoding="utf8") as f, open( - os.path.join(directory, "oov_counts.txt"), "w", encoding="utf8" - ) as cf: - for oov in sorted(self.oovs_found.keys(), key=lambda x: (-self.oovs_found[x], x)): - f.write(oov + "\n") - cf.write(f"{oov}\t{self.oovs_found[oov]}\n") - - @property - def silences(self) -> set: - """ - Set of silence phones - """ - return self.config.silence_phones + self.oovs_found.update(dictionary.oovs_found) @property def default_dictionary(self) -> PronunciationDictionary: - """Default PronunciationDictionary""" + """Default pronunciation dictionary""" return self.get_dictionary("default") def get_dictionary_name(self, speaker: Union[str, Speaker]) -> str: @@ -134,7 +96,7 @@ def get_dictionary_name(self, speaker: Union[str, Speaker]) -> str: Returns ------- str - PronunciationDictionary name for the speaker + Dictionary name for the speaker """ if not isinstance(speaker, str): speaker = speaker.name @@ -154,11 +116,11 @@ def get_dictionary(self, speaker: Union[Speaker, str]) -> PronunciationDictionar Returns ------- :class:`~montreal_forced_aligner.dictionary.PronunciationDictionary` - PronunciationDictionary for the speaker + Pronunciation dictionary for the speaker """ return 
self.dictionary_mapping[self.get_dictionary_name(speaker)] - def write(self, write_disambiguation: Optional[bool] = False) -> None: + def write_lexicon_information(self, write_disambiguation: Optional[bool] = False) -> None: """ Write all child dictionaries to the temporary directory @@ -167,10 +129,22 @@ def write(self, write_disambiguation: Optional[bool] = False) -> None: write_disambiguation: bool, optional Flag to use disambiguation symbols in the output """ + os.makedirs(self.phones_dir, exist_ok=True) + for d in self.dictionary_mapping.values(): + d.generate_mappings() + if d.max_disambiguation_symbol > self.max_disambiguation_symbol: + self.max_disambiguation_symbol = d.max_disambiguation_symbol + self._write_word_boundaries() + self._write_phone_map_file() + self._write_phone_sets() + self._write_phone_symbol_table() + self._write_disambig() + self._write_topo() + self._write_extra_questions() for d in self.dictionary_mapping.values(): d.write(write_disambiguation) - def set_word_set(self, word_set: Collection[str]) -> None: + def set_lexicon_word_set(self, word_set: Collection[str]) -> None: """ Limit output to a subset of overall words @@ -180,11 +154,37 @@ def set_word_set(self, word_set: Collection[str]) -> None: Word set to limit generated files to """ for d in self.dictionary_mapping.values(): - d.set_word_set(word_set) + d.set_lexicon_word_set(word_set) @property - def output_paths(self) -> Dict[str, str]: + def output_paths(self) -> dict[str, str]: """ - Mapping of output directory for child dictionaries + Mapping of output directory for child directories """ - return {d.name: d.output_directory for d in self.dictionary_mapping.values()} + return {d.name: d.dictionary_output_directory for d in self.dictionary_mapping.values()} + + +class MultispeakerDictionary(MultispeakerDictionaryMixin): + """ + Class for processing multi- and single-speaker pronunciation dictionaries + + See Also + -------- + 
:class:`~montreal_forced_aligner.dictionary.multispeaker.MultispeakerDictionaryMixin` + For dictionary parsing parameters + """ + + @property + def data_source_identifier(self) -> str: + """Name of the dictionary""" + return f"{self.name}" + + @property + def identifier(self) -> str: + """Name of the dictionary""" + return f"{self.data_source_identifier}" + + @property + def output_directory(self) -> str: + """Root temporary directory to store all dictionary information""" + return os.path.join(self.temporary_directory, self.identifier) diff --git a/montreal_forced_aligner/dictionary/base_dictionary.py b/montreal_forced_aligner/dictionary/pronunciation.py similarity index 51% rename from montreal_forced_aligner/dictionary/base_dictionary.py rename to montreal_forced_aligner/dictionary/pronunciation.py index f09f9912..39d320ff 100644 --- a/montreal_forced_aligner/dictionary/base_dictionary.py +++ b/montreal_forced_aligner/dictionary/pronunciation.py @@ -2,110 +2,91 @@ from __future__ import annotations -import logging import math import os import subprocess import sys from collections import Counter, defaultdict -from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set, Tuple, Union +from typing import TYPE_CHECKING, Any, Collection, Optional if TYPE_CHECKING: - from ..abc import ReversedMappingType, DictionaryEntryType - -from ..abc import Dictionary -from ..config.dictionary_config import DictionaryConfig -from ..exceptions import DictionaryError, DictionaryFileError -from ..models import DictionaryModel -from ..utils import thirdparty_binary -from .data import DictionaryData + from montreal_forced_aligner.dictionary.multispeaker import MultispeakerDictionaryMixin + from montreal_forced_aligner.abc import ( + ReversedMappingType, + ) + +from montreal_forced_aligner.dictionary.mixins import ( + DictionaryData, + SplitWordsFunction, + TemporaryDictionaryMixin, +) +from montreal_forced_aligner.exceptions import DictionaryError, DictionaryFileError 
+from montreal_forced_aligner.models import DictionaryModel +from montreal_forced_aligner.utils import thirdparty_binary __all__ = [ - "PronunciationDictionary", + "PronunciationDictionaryMixin", ] -class PronunciationDictionary(Dictionary): +class PronunciationDictionaryMixin(TemporaryDictionaryMixin): """ - Class containing information about a pronunciation dictionary + Abstract mixin class containing information about a pronunciation dictionary Parameters ---------- - dictionary_model : :class:`~montreal_forced_aligner.models.DictionaryModel` - MFA Dictionary model - output_directory : str - Path to a directory to store files for Kaldi - config: DictionaryConfig - Configuration for generating lexicons - word_set : Collection[str], optional - Word set to limit output files - logger: :class:`~logging.Logger`, optional - Logger to output information to + dictionary_path : str + Path to pronunciation dictionary + root_dictionary : :class:`~montreal_forced_aligner.dictionary.mixins.TemporaryDictionaryMixin`, optional + Optional root dictionary to take phone information from + + See Also + -------- + :class:`~montreal_forced_aligner.dictionary.mixins.DictionaryMixin` + For dictionary parsing parameters + :class:`~montreal_forced_aligner.abc.TemporaryDirectoryMixin` + For temporary directory parameters + + Attributes + ---------- + dictionary_model: DictionaryModel + Dictionary model to load + words: WordsType + Words mapped to their pronunciations + graphemes: set[str] + Set of graphemes in the dictionary + words_mapping: MappingType + Mapping of words to integer IDs + lexicon_word_set: set[str] + Word set to limit output of lexicon files """ - topo_template = " {cur_state} {cur_state} {cur_state} 0.75 {next_state} 0.25 " - topo_sil_template = " {cur_state} {cur_state} {transitions} " - topo_transition_template = " {} {}" - positions: List[str] = ["_B", "_E", "_I", "_S"] - - def __init__( - self, - dictionary_model: Union[DictionaryModel, str], - output_directory: 
str, - config: Optional[DictionaryConfig] = None, - word_set: Optional[Collection[str]] = None, - logger: Optional[logging.Logger] = None, - ): - if isinstance(dictionary_model, str): - dictionary_model = DictionaryModel(dictionary_model) - if config is None: - config = DictionaryConfig() - super().__init__(dictionary_model, config) - self.output_directory = os.path.join(output_directory, self.name) - os.makedirs(self.output_directory, exist_ok=True) - self.log_file = os.path.join(self.output_directory, f"{self.name}.log") - if logger is None: - self.logger = logging.getLogger("dictionary_setup") - self.logger.setLevel(logging.INFO) - handler = logging.FileHandler(self.log_file, "w", "utf-8") - handler.setFormatter = logging.Formatter("%(name)s %(message)s") - self.logger.addHandler(handler) - else: - self.logger = logger - self.oovs_found = Counter() - + def __init__(self, dictionary_path, root_dictionary=None, **kwargs): + super().__init__(**kwargs) + self.dictionary_model = DictionaryModel(dictionary_path) + self.root_dictionary = root_dictionary + os.makedirs(self.dictionary_output_directory, exist_ok=True) self.words = {} self.graphemes = set() - self.all_words = defaultdict(list) - self.words[self.config.silence_word] = [ - {"pronunciation": (self.config.nonoptional_silence_phone,), "probability": 1} - ] - self.words[self.config.oov_word] = [ - {"pronunciation": (self.config.oov_phone,), "probability": 1} + self.words[self.silence_word] = [ + {"pronunciation": (self.nonoptional_silence_phone,), "probability": 1} ] - - progress = f'Parsing dictionary "{self.name}"' - if self.dictionary_model.pronunciation_probabilities: - progress += " with pronunciation probabilities" - else: - progress += " without pronunciation probabilities" - if self.dictionary_model.silence_probabilities: - progress += " with silence probabilities" - else: - progress += " without silence probabilities" - self.logger.info(progress) + self.words[self.oov_word] = [{"pronunciation": 
(self.oov_phone,), "probability": 1}] + self.lookup_cache = {} + self.int_cache = {} + self.check_cache = {} with open(self.dictionary_model.path, "r", encoding="utf8") as inf: for i, line in enumerate(inf): line = line.strip() if not line: continue line = line.split() - word = self.config.sanitize(line.pop(0).lower()) + word = self.sanitize(line.pop(0).lower()) if not line: raise DictionaryError( f"Line {i} of {self.dictionary_model.path} does not have a pronunciation." ) - if word in [self.config.silence_word, self.config.oov_word]: + if word in [self.silence_word, self.oov_word]: continue self.graphemes.update(word) prob = 1 @@ -121,8 +102,8 @@ def __init__( right_sil_prob = None left_sil_prob = None left_nonsil_prob = None - if self.config.multilingual_ipa: - pron = self.config.parse_ipa(line) + if self.multilingual_ipa: + pron = self.parse_ipa(line) else: pron = tuple(line) pronunciation = { @@ -133,10 +114,10 @@ def __init__( "left_sil_prob": left_sil_prob, "left_nonsil_prob": left_nonsil_prob, } - if self.config.multilingual_ipa: + if self.multilingual_ipa: pronunciation["original_pronunciation"] = tuple(line) - if not any(x in self.config.silence_phones for x in pron): - self.config.non_silence_phones.update(pron) + if not any(x in self.silence_phones for x in pron): + self.non_silence_phones.update(pron) if word in self.words and pron in {x["pronunciation"] for x in self.words[word]}: continue if word not in self.words: @@ -144,41 +125,39 @@ def __init__( self.words[word].append(pronunciation) # test whether a word is a clitic is_clitic = False - for cm in self.config.clitic_markers: + for cm in self.clitic_markers: if word.startswith(cm) or word.endswith(cm): is_clitic = True if is_clitic: - self.config.clitic_set.add(word) + self.clitic_set.add(word) self.words_mapping = {} - if word_set is not None: - word_set = {y for x in word_set for y in self._lookup(x)} - word_set.add(self.config.silence_word) - word_set.add(self.config.oov_word) - self.word_set = 
word_set - if self.word_set is not None: - self.word_set = self.word_set | self.config.clitic_set + self._dictionary_data: Optional[DictionaryData] = None + self.lexicon_word_set = None if not self.graphemes: raise DictionaryFileError( f"No words were found in the dictionary path {self.dictionary_model.path}" ) + @property + def name(self) -> str: + """Name of the dictionary""" + return self.dictionary_model.name + def __hash__(self) -> Any: """Return the hash of a given dictionary""" return hash(self.dictionary_model.path) @property - def output_paths(self) -> Dict[str, str]: - """ - Mapping of output directory for this dictionary - """ - return {self.name: self.output_directory} + def dictionary_output_directory(self) -> str: + """Temporary directory to store all dictionary information""" + return os.path.join(self.temporary_directory, self.name) @property - def silences(self) -> Set[str]: + def silences(self) -> set[str]: """ Set of symbols that correspond to silence """ - return self.config.silence_phones + return self.silence_phones def data(self, word_set: Optional[Collection[str]] = None) -> DictionaryData: """ @@ -199,9 +178,9 @@ def word_check(word): """Check whether a word should be included in the output""" if word in word_set: return True - if word in self.config.clitic_set: + if word in self.clitic_set: return True - if word in self.config.specials_set: + if word in self.specials_set: return True return False @@ -216,38 +195,55 @@ def word_check(word): reversed_word_mapping = self.reversed_word_mapping words = self.words return DictionaryData( - self.config, + self.dictionary_options, + self.construct_sanitize_function(), + self.construct_split_words_function(), words_mapping, reversed_word_mapping, - self.reversed_phone_mapping, words, + self.lookup_cache, ) - def set_word_set(self, word_set: Collection[str]) -> None: + def construct_split_words_function(self) -> SplitWordsFunction: """ - Limit output to a subset of overall words + Construct a 
:class:`~montreal_forced_aligner.dictionary.mixins.SplitWordsFunction` to use in multiprocessing jobs + + Returns + ------- + :class:`~montreal_forced_aligner.dictionary.mixins.SplitWordsFunction` + Function for sanitizing text + """ + f = SplitWordsFunction( + self.punctuation, + self.clitic_markers, + self.compound_markers, + self.brackets, + self.clitic_set, + set(self.words.keys()), + ) + + return f + + def set_lexicon_word_set(self, word_set: Collection[str]) -> None: + """ + Limit lexicon output to a subset of overall words Parameters ---------- word_set: Collection[str] Word set to limit generated files to """ - word_set = {y for x in word_set for y in self._lookup(x)} - word_set.add(self.config.silence_word) - word_set.add(self.config.oov_word) - self.word_set = word_set | self.config.clitic_set - self.generate_mappings() + split_function = self.construct_split_words_function() + self.lexicon_word_set = {self.silence_word, self.oov_word} + self.lexicon_word_set.update(self.clitic_set) + for word in word_set: + self.lookup_cache[word] = split_function(word) + self.lexicon_word_set.update(self.lookup_cache[word]) - @property - def actual_words(self) -> Dict[str, "DictionaryEntryType"]: - """ - Mapping of words to integer IDs without Kaldi-internal words - """ - return { - k: v for k, v in self.words.items() if k not in self.config.specials_set and len(v) - } + self._dictionary_data = self.data(self.lexicon_word_set) + self.generate_mappings() - def split_clitics(self, item: str) -> List[str]: + def split_clitics(self, item: str) -> list[str]: """ Split a word into subwords based on clitic and compound markers @@ -258,10 +254,15 @@ def split_clitics(self, item: str) -> List[str]: Returns ------- - List[str] + list[str] List of subwords """ - return self.data().split_clitics(item) + + return self.construct_split_words_function()(item) + + def __bool__(self): + """Check that the dictionary contains words""" + return bool(self.words) def __len__(self) -> int: 
"""Return the number of pronunciations across all words""" @@ -282,19 +283,15 @@ def exclude_for_alignment(self, word: str) -> bool: bool True if there is no word set on the dictionary, or if the word is in the given word set """ - if self.word_set is None: + if self.lexicon_word_set is None: return False - if word not in self.word_set and word not in self.config.clitic_set: + if word not in self.lexicon_word_set and word not in self.clitic_set: return True return False - @property - def phone_mapping(self) -> Dict[str, int]: - return self.config.phone_mapping - def generate_mappings(self) -> None: """ - Generate phone and word mappings from text to integer IDs + Generate word mappings from text to integer IDs """ self.words_mapping = {} i = 0 @@ -310,6 +307,7 @@ def generate_mappings(self) -> None: self.words_mapping[""] = i + 3 self.oovs_found = Counter() self.add_disambiguation() + self._dictionary_data = self.data() def add_disambiguation(self) -> None: """ @@ -343,19 +341,19 @@ def add_disambiguation(self) -> None: disambig = last_used[pron] p["disambiguation"] = disambig if last_used: - self.config.max_disambiguation_symbol = max( - self.config.max_disambiguation_symbol, max(last_used.values()) + self.max_disambiguation_symbol = max( + self.max_disambiguation_symbol, max(last_used.values()) ) - def create_utterance_fst(self, text: List[str], frequent_words: List[Tuple[str, int]]) -> str: + def create_utterance_fst(self, text: list[str], frequent_words: list[tuple[str, int]]) -> str: """ Create an FST for an utterance with frequent words as a unigram language model Parameters ---------- - text: List[str] + text: list[str] Text of the utterance - frequent_words: List[Tuple[str, int]] + frequent_words: list[tuple[str, int]] Frequent words to incorporate into the FST Returns ------- @@ -374,7 +372,7 @@ def create_utterance_fst(self, text: List[str], frequent_words: List[Tuple[str, fst_text += f"0 {-1 * math.log(1 / num_words)}\n" return fst_text - def 
to_int(self, item: str) -> List[int]: + def to_int(self, item: str) -> list[int]: """ Convert a given word into integer IDs @@ -385,12 +383,14 @@ def to_int(self, item: str) -> List[int]: Returns ------- - List[int] + list[int] List of integer IDs corresponding to each subword """ - return self.data().to_int(item) + if item not in self.int_cache: + self.int_cache[item] = self._dictionary_data.to_int(item) + return self.int_cache[item] - def _lookup(self, item: str) -> List[str]: + def _lookup(self, item: str) -> list[str]: """ Look up a word and return the list of sub words if necessary taking into account clitic and compound markers @@ -401,10 +401,15 @@ def _lookup(self, item: str) -> List[str]: Returns ------- - List[str] + list[str] List of subwords that are in the dictionary """ - return self.data().lookup(item) + if item not in self.lookup_cache: + if self._dictionary_data is not None: + self.lookup_cache[item] = self._dictionary_data.lookup(item) + else: + self.lookup_cache[item] = self.construct_split_words_function()(item) + return self.lookup_cache[item] def check_word(self, item: str) -> bool: """ @@ -421,7 +426,9 @@ def check_word(self, item: str) -> bool: bool True if the look up would not result in an OOV item """ - return self.data().check_word(item) + if item not in self.check_cache: + self.check_cache[item] = self._dictionary_data.check_word(item) + return self.check_cache[item] @property def reversed_word_mapping(self) -> ReversedMappingType: @@ -448,30 +455,44 @@ def oov_int(self) -> int: """ The integer id for out of vocabulary items """ - return self.words_mapping[self.config.oov_word] + return self.words_mapping[self.oov_word] @property def phones_dir(self) -> str: """ Directory to store information Kaldi needs about phones """ - return os.path.join(self.output_directory, "phones") + if self.root_dictionary is not None: + return self.root_dictionary.phones_dir + + return os.path.join(self.dictionary_output_directory, "phones") @property def 
words_symbol_path(self) -> str: """ Path of word to int mapping file for the dictionary """ - return os.path.join(self.output_directory, "words.txt") + return os.path.join(self.dictionary_output_directory, "words.txt") + + @property + def lexicon_fst_path(self) -> str: + """ + Path of disambiguated lexicon fst (L.fst) + """ + return os.path.join(self.dictionary_output_directory, "L.fst") @property - def disambig_path(self) -> str: + def lexicon_disambig_fst_path(self) -> str: """ Path of disambiguated lexicon fst (L.fst) """ - return os.path.join(self.output_directory, "L_disambig.fst") + return os.path.join(self.dictionary_output_directory, "L_disambig.fst") - def write(self, write_disambiguation: Optional[bool] = False) -> None: + def write( + self, + write_disambiguation: bool = False, + debug=False, + ) -> None: """ Write the files necessary for Kaldi @@ -479,18 +500,21 @@ def write(self, write_disambiguation: Optional[bool] = False) -> None: ---------- write_disambiguation: bool, optional Flag for including disambiguation information - """ - self.logger.info("Creating dictionary information...") - os.makedirs(self.phones_dir, exist_ok=True) - self.generate_mappings() + debug: bool, optional + Flag for whether to keep temporary files, defaults to False + """ + if self.root_dictionary is None: + self.generate_mappings() + os.makedirs(self.phones_dir, exist_ok=True) + self._write_word_boundaries() + self._write_phone_map_file() + self._write_phone_sets() + self._write_phone_symbol_table() + self._write_disambig() + self._write_topo() + self._write_extra_questions() + self._write_graphemes() - self._write_phone_map_file() - self._write_phone_sets() - self._write_phone_symbol_table() - self._write_disambig() - self._write_topo() - self._write_word_boundaries() - self._write_extra_questions() self._write_word_file() self._write_align_lexicon() if write_disambiguation: @@ -498,25 +522,23 @@ def write(self, write_disambiguation: Optional[bool] = False) -> None: else: 
self._write_basic_fst_text() self._write_fst_binary(write_disambiguation=write_disambiguation) - self.cleanup() + if not debug: + self.cleanup() def cleanup(self) -> None: """ Clean up temporary files in the output directory """ - if not self.config.debug: - if os.path.exists(os.path.join(self.output_directory, "temp.fst")): - os.remove(os.path.join(self.output_directory, "temp.fst")) - if os.path.exists(os.path.join(self.output_directory, "lexicon.text.fst")): - os.remove(os.path.join(self.output_directory, "lexicon.text.fst")) + if os.path.exists(os.path.join(self.dictionary_output_directory, "temp.fst")): + os.remove(os.path.join(self.dictionary_output_directory, "temp.fst")) + if os.path.exists(os.path.join(self.dictionary_output_directory, "lexicon.text.fst")): + os.remove(os.path.join(self.dictionary_output_directory, "lexicon.text.fst")) def _write_graphemes(self) -> None: """ Write graphemes to temporary directory """ - outfile = os.path.join(self.output_directory, "graphemes.txt") - if os.path.exists(outfile): - return + outfile = os.path.join(self.dictionary_output_directory, "graphemes.txt") with open(outfile, "w", encoding="utf8") as f: for char in sorted(self.graphemes): f.write(char + "\n") @@ -553,72 +575,11 @@ def export_lexicon( else: f.write(f"{w}\t{phones}\n") - def _write_phone_map_file(self) -> None: - """ - Write the phone map to the temporary directory - """ - outfile = os.path.join(self.output_directory, "phone_map.txt") - if os.path.exists(outfile): - return - with open(outfile, "w", encoding="utf8") as f: - for sp in self.config.silence_phones: - if self.config.position_dependent_phones: - new_phones = [sp + x for x in ["", ""] + self.positions] - else: - new_phones = [sp] - f.write(" ".join(new_phones) + "\n") - for nsp in self.config.non_silence_phones: - if self.config.position_dependent_phones: - new_phones = [nsp + x for x in [""] + self.positions] - else: - new_phones = [nsp] - f.write(" ".join(new_phones) + "\n") - - def 
_write_phone_symbol_table(self) -> None: - """ - Write the phone mapping to the temporary directory - """ - outfile = os.path.join(self.output_directory, "phones.txt") - if os.path.exists(outfile): - return - with open(outfile, "w", encoding="utf8") as f: - for p, i in sorted(self.phone_mapping.items(), key=lambda x: x[1]): - f.write(f"{p} {i}\n") - - def _write_word_boundaries(self) -> None: - """ - Write the word boundaries file to the temporary directory - """ - boundary_path = os.path.join(self.output_directory, "phones", "word_boundary.txt") - boundary_int_path = os.path.join(self.output_directory, "phones", "word_boundary.int") - if os.path.exists(boundary_path) and os.path.exists(boundary_int_path): - return - with open(boundary_path, "w", encoding="utf8") as f, open( - boundary_int_path, "w", encoding="utf8" - ) as intf: - if self.config.position_dependent_phones: - for p in sorted(self.phone_mapping.keys(), key=lambda x: self.phone_mapping[x]): - if p == "" or p.startswith("#"): - continue - cat = "nonword" - if p.endswith("_B"): - cat = "begin" - elif p.endswith("_S"): - cat = "singleton" - elif p.endswith("_I"): - cat = "internal" - elif p.endswith("_E"): - cat = "end" - f.write(" ".join([p, cat]) + "\n") - intf.write(" ".join([str(self.phone_mapping[p]), cat]) + "\n") - def _write_word_file(self) -> None: """ Write the word mapping to the temporary directory """ - words_path = os.path.join(self.output_directory, "words.txt") - if os.path.exists(words_path): - return + words_path = os.path.join(self.dictionary_output_directory, "words.txt") if sys.platform == "win32": newline = "" else: @@ -631,6 +592,10 @@ def _write_align_lexicon(self) -> None: """ Write the alignment lexicon text file to the temporary directory """ + if self.root_dictionary is None: + phone_mapping = self.phone_mapping + else: + phone_mapping = self.root_dictionary.phone_mapping path = os.path.join(self.phones_dir, "align_lexicon.int") if os.path.exists(path): return @@ -647,7 +612,7 
@@ def _write_align_lexicon(self) -> None: ): phones = list(pron["pronunciation"]) - if self.config.position_dependent_phones: + if self.position_dependent_phones: if len(phones) == 1: phones[0] += "_S" else: @@ -658,191 +623,48 @@ def _write_align_lexicon(self) -> None: phones[j] += "_E" else: phones[j] += "_I" - p = " ".join(str(self.phone_mapping[x]) for x in phones) + p = " ".join(str(phone_mapping[x]) for x in phones) f.write(f"{i} {i} {p}\n".format(i=i, p=p)) - def _write_topo(self) -> None: - """ - Write the topo file to the temporary directory - """ - filepath = os.path.join(self.output_directory, "topo") - if os.path.exists(filepath): - return - sil_transp = 1 / (self.config.num_silence_states - 1) - initial_transition = [ - self.topo_transition_template.format(x, sil_transp) - for x in range(self.config.num_silence_states - 1) - ] - middle_transition = [ - self.topo_transition_template.format(x, sil_transp) - for x in range(1, self.config.num_silence_states) - ] - final_transition = [ - self.topo_transition_template.format(self.config.num_silence_states - 1, 0.75), - self.topo_transition_template.format(self.config.num_silence_states, 0.25), - ] - with open(filepath, "w") as f: - f.write("\n") - f.write("\n") - f.write("\n") - phones = self.config.kaldi_non_silence_phones - f.write(f"{' '.join(str(self.phone_mapping[x]) for x in phones)}\n") - f.write("\n") - states = [ - self.topo_template.format(cur_state=x, next_state=x + 1) - for x in range(self.config.num_non_silence_states) - ] - f.write("\n".join(states)) - f.write(f"\n {self.config.num_non_silence_states} \n") - f.write("\n") - - f.write("\n") - f.write("\n") - - phones = self.config.kaldi_silence_phones - f.write(f"{' '.join(str(self.phone_mapping[x]) for x in phones)}\n") - f.write("\n") - states = [] - for i in range(self.config.num_silence_states): - if i == 0: - transition = " ".join(initial_transition) - elif i == self.config.num_silence_states - 1: - transition = " ".join(final_transition) 
- else: - transition = " ".join(middle_transition) - states.append(self.topo_sil_template.format(cur_state=i, transitions=transition)) - f.write("\n".join(states)) - f.write(f"\n {self.config.num_silence_states} \n") - f.write("\n") - f.write("\n") - - def _write_phone_sets(self) -> None: - """ - Write phone symbol sets to the temporary directory - """ - sharesplit = ["shared", "split"] - if not self.config.shared_silence_phones: - sil_sharesplit = ["not-shared", "not-split"] - else: - sil_sharesplit = sharesplit - - sets_file = os.path.join(self.output_directory, "phones", "sets.txt") - roots_file = os.path.join(self.output_directory, "phones", "roots.txt") - - sets_int_file = os.path.join(self.output_directory, "phones", "sets.int") - roots_int_file = os.path.join(self.output_directory, "phones", "roots.int") - if ( - os.path.exists(sets_file) - and os.path.exists(roots_file) - and os.path.exists(sets_int_file) - and os.path.exists(roots_int_file) - ): - return - - with open(sets_file, "w", encoding="utf8") as setf, open( - roots_file, "w", encoding="utf8" - ) as rootf, open(sets_int_file, "w", encoding="utf8") as setintf, open( - roots_int_file, "w", encoding="utf8" - ) as rootintf: - - # process silence phones - for i, sp in enumerate(self.config.silence_phones): - if self.config.position_dependent_phones: - mapped = [sp + x for x in [""] + self.positions] - else: - mapped = [sp] - setf.write(" ".join(mapped) + "\n") - setintf.write(" ".join(map(str, (self.phone_mapping[x] for x in mapped))) + "\n") - if i == 0: - line = sil_sharesplit + mapped - lineint = sil_sharesplit + [str(self.phone_mapping[x]) for x in mapped] - else: - line = sharesplit + mapped - lineint = sharesplit + [str(self.phone_mapping[x]) for x in mapped] - rootf.write(" ".join(line) + "\n") - rootintf.write(" ".join(lineint) + "\n") - - # process nonsilence phones - for nsp in sorted(self.config.non_silence_phones): - if self.config.position_dependent_phones: - mapped = [nsp + x for x in 
self.positions] - else: - mapped = [nsp] - setf.write(" ".join(mapped) + "\n") - setintf.write(" ".join(map(str, (self.phone_mapping[x] for x in mapped))) + "\n") - line = sharesplit + mapped - lineint = sharesplit + [str(self.phone_mapping[x]) for x in mapped] - rootf.write(" ".join(line) + "\n") - rootintf.write(" ".join(lineint) + "\n") - - def _write_extra_questions(self) -> None: - """ - Write extra questions symbols to the temporary directory - """ - phone_extra = os.path.join(self.phones_dir, "extra_questions.txt") - phone_extra_int = os.path.join(self.phones_dir, "extra_questions.int") - if os.path.exists(phone_extra) and os.path.exists(phone_extra_int): - return - with open(phone_extra, "w", encoding="utf8") as outf, open( - phone_extra_int, "w", encoding="utf8" - ) as intf: - silences = self.config.kaldi_silence_phones - outf.write(" ".join(silences) + "\n") - intf.write(" ".join(str(self.phone_mapping[x]) for x in silences) + "\n") - - non_silences = self.config.kaldi_non_silence_phones - outf.write(" ".join(non_silences) + "\n") - intf.write(" ".join(str(self.phone_mapping[x]) for x in non_silences) + "\n") - if self.config.position_dependent_phones: - for p in self.positions: - line = [x + p for x in sorted(self.config.non_silence_phones)] - outf.write(" ".join(line) + "\n") - intf.write(" ".join(str(self.phone_mapping[x]) for x in line) + "\n") - for p in [""] + self.positions: - line = [x + p for x in sorted(self.config.silence_phones)] - outf.write(" ".join(line) + "\n") - intf.write(" ".join(str(self.phone_mapping[x]) for x in line) + "\n") - - def _write_disambig(self) -> None: - """ - Write disambiguation symbols to the temporary directory - """ - disambig = os.path.join(self.phones_dir, "disambiguation_symbols.txt") - disambig_int = os.path.join(self.phones_dir, "disambiguation_symbols.int") - if os.path.exists(disambig) and os.path.exists(disambig_int): - return - with open(disambig, "w", encoding="utf8") as outf, open( - disambig_int, "w", 
encoding="utf8" - ) as intf: - for d in sorted( - self.config.disambiguation_symbols, key=lambda x: self.phone_mapping[x] - ): - outf.write(f"{d}\n") - intf.write(f"{self.phone_mapping[d]}\n") - - def _write_fst_binary(self, write_disambiguation: Optional[bool] = False) -> None: + def _write_fst_binary( + self, + write_disambiguation: Optional[bool] = False, + ) -> None: """ Write the binary fst file to the temporary directory + See Also + -------- + :kaldi_src:`fstaddselfloops` + Relevant Kaldi binary + :openfst_src:`fstcompile` + Relevant OpenFst binary + :openfst_src:`fstarcsort` + Relevant OpenFst binary + Parameters ---------- write_disambiguation: bool, optional Flag for including disambiguation symbols """ if write_disambiguation: - lexicon_fst_path = os.path.join(self.output_directory, "lexicon_disambig.text.fst") - output_fst = os.path.join(self.output_directory, "L_disambig.fst") + lexicon_fst_path = os.path.join( + self.dictionary_output_directory, "lexicon_disambig.text.fst" + ) + output_fst = os.path.join(self.dictionary_output_directory, "L_disambig.fst") else: - lexicon_fst_path = os.path.join(self.output_directory, "lexicon.text.fst") - output_fst = os.path.join(self.output_directory, "L.fst") - if os.path.exists(output_fst): - return - - phones_file_path = os.path.join(self.output_directory, "phones.txt") - words_file_path = os.path.join(self.output_directory, "words.txt") + lexicon_fst_path = os.path.join(self.dictionary_output_directory, "lexicon.text.fst") + output_fst = os.path.join(self.dictionary_output_directory, "L.fst") + if self.root_dictionary is not None: + phone_mapping = self.root_dictionary.phone_mapping + phones_file_path = self.root_dictionary.phone_symbol_table_path + else: + phone_mapping = self.phone_mapping + phones_file_path = self.phone_symbol_table_path + words_file_path = os.path.join(self.dictionary_output_directory, "words.txt") - log_path = os.path.join(self.output_directory, "fst.log") - temp_fst_path = 
os.path.join(self.output_directory, "temp.fst") + log_path = os.path.join(self.dictionary_output_directory, "fst.log") + temp_fst_path = os.path.join(self.dictionary_output_directory, "temp.fst") with open(log_path, "w") as log_file: compile_proc = subprocess.Popen( [ @@ -858,11 +680,15 @@ def _write_fst_binary(self, write_disambiguation: Optional[bool] = False) -> Non ) compile_proc.communicate() if write_disambiguation: - temp2_fst_path = os.path.join(self.output_directory, "temp2.fst") - phone_disambig_path = os.path.join(self.output_directory, "phone_disambig.txt") - word_disambig_path = os.path.join(self.output_directory, "word_disambig.txt") + temp2_fst_path = os.path.join(self.dictionary_output_directory, "temp2.fst") + word_disambig_path = os.path.join( + self.dictionary_output_directory, "word_disambig0.txt" + ) + phone_disambig_path = os.path.join( + self.dictionary_output_directory, "phone_disambig0.txt" + ) with open(phone_disambig_path, "w") as f: - f.write(str(self.phone_mapping["#0"])) + f.write(str(phone_mapping["#0"])) with open(word_disambig_path, "w") as f: f.write(str(self.words_mapping["#0"])) selfloop_proc = subprocess.Popen( @@ -901,27 +727,26 @@ def _write_basic_fst_text(self) -> None: """ Write the L.fst text file to the temporary directory """ - sil_disambiguation = None nonoptional_silence = None optional_silence_phone = None - lexicon_fst_path = os.path.join(self.output_directory, "lexicon.text.fst") + lexicon_fst_path = os.path.join(self.dictionary_output_directory, "lexicon.text.fst") start_state = 0 silence_state = 0 silence_cost = 0 no_silence_cost = 0 loop_state = 0 next_state = 1 - if self.config.silence_probability: - optional_silence_phone = self.config.optional_silence_phone - nonoptional_silence = self.config.nonoptional_silence_phone + if self.silence_probability: + optional_silence_phone = self.optional_silence_phone + nonoptional_silence = self.nonoptional_silence_phone - silence_cost = -1 * 
math.log(self.config.silence_probability) - no_silence_cost = -1 * math.log(1.0 - self.config.silence_probability) + silence_cost = -1 * math.log(self.silence_probability) + no_silence_cost = -1 * math.log(1.0 - self.silence_probability) loop_state = 1 silence_state = 2 with open(lexicon_fst_path, "w", encoding="utf8") as outf: - if self.config.silence_probability: + if self.silence_probability: outf.write( "\t".join( map(str, [start_state, loop_state, "", "", no_silence_cost]) @@ -938,45 +763,13 @@ def _write_basic_fst_text(self) -> None: ) + "\n" ) # silence - if sil_disambiguation is None: - outf.write( - "\t".join( - map(str, [silence_state, loop_state, optional_silence_phone, ""]) - ) - + "\n" - ) # no cost - next_state = 3 - else: - silence_disambiguation_state = next_state - next_state += 1 - outf.write( - "\t".join( - map( - str, - [ - silence_state, - silence_disambiguation_state, - optional_silence_phone, - "", - ], - ) - ) - + "\n" - ) - outf.write( - "\t".join( - map( - str, - [ - silence_disambiguation_state, - loop_state, - sil_disambiguation, - "", - ], - ) - ) - + "\n" + outf.write( + "\t".join( + map(str, [silence_state, loop_state, optional_silence_phone, ""]) ) + + "\n" + ) # no cost + next_state = 3 for w in sorted(self.words.keys()): if self.exclude_for_alignment(w): @@ -987,7 +780,7 @@ def _write_basic_fst_text(self) -> None: ): phones = list(pron["pronunciation"]) prob = pron["probability"] - if self.config.position_dependent_phones: + if self.position_dependent_phones: if len(phones) == 1: phones[0] += "_S" else: @@ -1016,7 +809,7 @@ def _write_basic_fst_text(self) -> None: current_state = next_state next_state += 1 else: # transition on last phone to loop state - if self.config.silence_probability: + if self.silence_probability: outf.write( f"{current_state}\t{loop_state}\t{p}\t{word_or_eps}\t{local_no_silence_cost}\n" ) @@ -1031,22 +824,34 @@ def _write_basic_fst_text(self) -> None: outf.write(f"{loop_state}\t0\n") - def 
_write_fst_text_disambiguated(self) -> None: + def _write_fst_text_disambiguated( + self, multispeaker_dictionary: Optional[MultispeakerDictionaryMixin] = None + ) -> None: """ Write the text L_disambig.fst file to the temporary directory + + Parameters + ---------- + multispeaker_dictionary: MultispeakerDictionaryMixin, optional + Main dictionary with phone mappings """ - lexicon_fst_path = os.path.join(self.output_directory, "lexicon_disambig.text.fst") - sil_disambiguation = f"#{self.config.max_disambiguation_symbol + 1}" - assert self.config.silence_probability + lexicon_fst_path = os.path.join( + self.dictionary_output_directory, "lexicon_disambig.text.fst" + ) + if multispeaker_dictionary is not None: + sil_disambiguation = f"#{multispeaker_dictionary.max_disambiguation_symbol + 1}" + else: + sil_disambiguation = f"#{self.max_disambiguation_symbol + 1}" + assert self.silence_probability start_state = 0 loop_state = 1 silence_state = 2 next_state = 3 - silence_phone = self.config.nonoptional_silence_phone + silence_phone = self.nonoptional_silence_phone - silence_cost = -1 * math.log(self.config.silence_probability) - no_silence_cost = -1 * math.log(1 - self.config.silence_probability) + silence_cost = -1 * math.log(self.silence_probability) + no_silence_cost = -1 * math.log(1 - self.silence_probability) with open(lexicon_fst_path, "w", encoding="utf8") as outf: outf.write( @@ -1075,7 +880,7 @@ def _write_fst_text_disambiguated(self) -> None: phones = list(pron["pronunciation"]) prob = pron["probability"] disambig_symbol = pron["disambiguation"] - if self.config.position_dependent_phones: + if self.position_dependent_phones: if len(phones) == 1: phones[0] += "_S" else: @@ -1116,3 +921,29 @@ def _write_fst_text_disambiguated(self) -> None: ) outf.write(f"{loop_state}\t0.0\n") + + +class PronunciationDictionary(PronunciationDictionaryMixin): + """ + Class for processing pronunciation dictionaries + + See Also + -------- + 
:class:`~montreal_forced_aligner.dictionary.pronunciation.PronunciationDictionaryMixin` + For acoustic model training parsing parameters + """ + + @property + def data_source_identifier(self) -> str: + """Dictionary name""" + return f"{self.name}" + + @property + def identifier(self) -> str: + """Dictionary name""" + return f"{self.data_source_identifier}" + + @property + def output_directory(self) -> str: + """Temporary directory for the dictionary""" + return os.path.join(self.temporary_directory, self.identifier) diff --git a/montreal_forced_aligner/exceptions.py b/montreal_forced_aligner/exceptions.py index e8f6627f..138ef7fb 100644 --- a/montreal_forced_aligner/exceptions.py +++ b/montreal_forced_aligner/exceptions.py @@ -5,15 +5,18 @@ """ from __future__ import annotations -from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple +import logging +import sys +from typing import TYPE_CHECKING, Collection, Optional from colorama import Fore, Style -from .helper import comma_join +from montreal_forced_aligner.helper import comma_join if TYPE_CHECKING: - from .dictionary import PronunciationDictionary - from .models import G2PModel + from montreal_forced_aligner.dictionary.pronunciation import PronunciationDictionaryMixin + from montreal_forced_aligner.models import G2PModel + from montreal_forced_aligner.textgrid import CtmInterval __all__ = [ @@ -125,6 +128,26 @@ def __str__(self) -> str: return f"{self.error_text(type(self).__name__)}: {self.message}" +class PlatformError(MFAError): + """ + Exception class for platform compatibility issues + + Parameters + ---------- + functionality_name: str + Functionality not available on the current platform + """ + + def __init__(self, functionality_name): + super().__init__() + self.message = f"Functionality for {self.emphasized_text(functionality_name)} is not available on {self.error_text(sys.platform)}." 
+ if sys.platform == "win32": + self.message += ( + f" If you'd like to use {self.emphasized_text(functionality_name)} on Windows, please follow the MFA installation " + f"instructions for the Windows Subsystem for Linux (WSL)." + ) + + class ThirdpartyError(MFAError): """ Exception class for errors in third party binary (usually Kaldi or OpenFst) @@ -307,6 +330,7 @@ def __init__(self, file_name: str, error: str): MFAError.__init__(self) self.file_name = file_name self.error = error + self.message = f"Reading {self.emphasized_text(self.file_name)} has the following error:\n\n{self.error}" class SoxError(CorpusReadError): @@ -334,11 +358,11 @@ class AlignmentError(MFAError): Parameters ---------- - error_logs: List[str] + error_logs: list[str] List of Kaldi log files with errors """ - def __init__(self, error_logs: List[str]): + def __init__(self, error_logs: list[str]): super().__init__() output = "\n".join(error_logs) self.message = ( @@ -353,12 +377,12 @@ class AlignmentExportError(AlignmentError): Parameters ---------- - error_dict: Dict[Tuple[str, int], str] + error_dict: dict[tuple[str, int], str] Error dictionary mapping export stage and job to the error encountered """ - def __init__(self, error_dict: Dict[Tuple[str, int], str]): + def __init__(self, error_dict: dict[tuple[str, int], str]): MFAError.__init__(self) message = "Error was encountered in processing CTMs:\n\n" @@ -367,6 +391,23 @@ def __init__(self, error_dict: Dict[Tuple[str, int], str]): self.message = message +class CtmError(AlignmentError): + """ + Class for errors in creating CTM intervals + + Parameters + ---------- + ctm: CtmInterval + CTM interval that was not parsed correctly + + """ + + def __init__(self, ctm: CtmInterval): + MFAError.__init__(self) + + self.message = f"Error was encountered in processing CTM interval: {ctm}" + + class NoSuccessfulAlignments(AlignerError): """ Class for errors where nothing could be aligned @@ -402,11 +443,11 @@ class 
PronunciationOrthographyMismatchError(AlignerError): ---------- g2p_model: :class:`~montreal_forced_aligner.models.G2PModel` Specified G2P model - dictionary: :class:`~montreal_forced_aligner.dictionary.PronunciationDictionary` + dictionary: :class:`~montreal_forced_aligner.dictionary.pronunciation.PronunciationDictionaryMixin` Specified dictionary """ - def __init__(self, g2p_model: G2PModel, dictionary: PronunciationDictionary): + def __init__(self, g2p_model: G2PModel, dictionary: PronunciationDictionaryMixin): super().__init__() missing_graphs = dictionary.graphemes - set(g2p_model.meta["graphemes"]) missing_graphs = [f"{self.error_text(x)}" for x in sorted(missing_graphs)] @@ -452,12 +493,12 @@ class PretrainedModelNotFoundError(ArgumentError): Model name model_type: str, optional Model type searched - available: List[str], optional + available: list[str], optional List of models that were found """ def __init__( - self, name: str, model_type: Optional[str] = None, available: Optional[List[str]] = None + self, name: str, model_type: Optional[str] = None, available: Optional[list[str]] = None ): super().__init__() extra = "" @@ -478,11 +519,11 @@ class MultipleModelTypesFoundError(ArgumentError): ---------- name: str Model name - possible_model_types: List[str] + possible_model_types: list[str] List of model types that have a model with the given name """ - def __init__(self, name: str, possible_model_types: List[str]): + def __init__(self, name: str, possible_model_types: list[str]): super().__init__() possible_model_types = [f"{self.error_text(x)}" for x in possible_model_types] @@ -502,11 +543,11 @@ class ModelExtensionError(ArgumentError): Model name model_type: str Model type - extensions: List[str] + extensions: list[str] Extensions that the model supports """ - def __init__(self, name: str, model_type: str, extensions: List[str]): + def __init__(self, name: str, model_type: str, extensions: list[str]): super().__init__() extra = "" if model_type: @@ 
-528,7 +569,7 @@ class ModelTypeNotSupportedError(ArgumentError): ---------- model_type: str Model type - model_types: List[str] + model_types: list[str] List of supported model types """ @@ -549,6 +590,19 @@ class ConfigError(MFAError): pass +class RootDirectoryError(ConfigError): + """ + Exception class for errors using the MFA_ROOT_DIR + """ + + def __init__(self, temporary_directory, variable): + super().__init__() + self.message = ( + f"Could not create a root MFA temporary directory (tried {self.error_text(temporary_directory)}), " + f"please specify a write-able directory via the {self.emphasized_text(variable)} environment variable." + ) + + class TrainerError(MFAError): """ Exception class for errors in trainers @@ -589,13 +643,13 @@ class KaldiProcessingError(MFAError): Parameters ---------- - error_logs: List[str] + error_logs: list[str] List of Kaldi logs that had errors log_file: str, optional Overall log file to find more information """ - def __init__(self, error_logs: List[str], log_file: Optional[str] = None): + def __init__(self, error_logs: list[str], log_file: Optional[str] = None): super().__init__() self.message = ( f"There were {len(error_logs)} job(s) with errors when running Kaldi binaries." @@ -605,18 +659,19 @@ def __init__(self, error_logs: List[str], log_file: Optional[str] = None): self.error_logs = error_logs self.log_file = log_file - def update_log_file(self, log_file: str) -> None: + def update_log_file(self, logger: logging.Logger) -> None: """ Update the log file output Parameters ---------- - log_file: str - Path to log file + logger: logging.Logger + Logger """ - self.log_file = log_file + if logger.handlers: + self.log_file = logger.handlers[0].baseFilename self.message = ( f"There were {len(self.error_logs)} job(s) with errors when running Kaldi binaries." 
) if self.log_file is not None: - self.message += f" For more details, please check {self.error_text(log_file)}" + self.message += f" For more details, please check {self.error_text(self.log_file)}" diff --git a/montreal_forced_aligner/g2p/__init__.py b/montreal_forced_aligner/g2p/__init__.py index c22599b6..19e67725 100644 --- a/montreal_forced_aligner/g2p/__init__.py +++ b/montreal_forced_aligner/g2p/__init__.py @@ -5,10 +5,20 @@ """ -from .generator import PyniniDictionaryGenerator -from .trainer import PyniniTrainer +from montreal_forced_aligner.g2p.generator import ( + OrthographicCorpusGenerator, + OrthographicWordListGenerator, + PyniniCorpusGenerator, + PyniniWordListGenerator, +) +from montreal_forced_aligner.g2p.trainer import PyniniTrainer -__all__ = ["generator", "trainer", "PyniniTrainer", "PyniniDictionaryGenerator"] - -PyniniTrainer.__module__ = "montreal_forced_aligner.g2p" -PyniniDictionaryGenerator.__module__ = "montreal_forced_aligner.g2p" +__all__ = [ + "generator", + "trainer", + "PyniniTrainer", + "PyniniCorpusGenerator", + "PyniniWordListGenerator", + "OrthographicCorpusGenerator", + "OrthographicWordListGenerator", +] diff --git a/montreal_forced_aligner/g2p/generator.py b/montreal_forced_aligner/g2p/generator.py index da75a209..aa5039c5 100644 --- a/montreal_forced_aligner/g2p/generator.py +++ b/montreal_forced_aligner/g2p/generator.py @@ -2,20 +2,22 @@ from __future__ import annotations import functools -import logging import multiprocessing as mp import os import queue import sys import time import traceback -from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set, Tuple, Union +from typing import TYPE_CHECKING, Any, Union import tqdm -from ..config import TEMP_DIR +from ..abc import TopLevelMfaWorker +from ..corpus.text_corpus import TextCorpusMixin from ..exceptions import G2PError -from ..multiprocessing import Counter, Stopped +from ..models import G2PModel +from ..utils import Counter, Stopped +from .mixins 
import G2PTopLevelMixin try: import pynini @@ -32,10 +34,15 @@ if TYPE_CHECKING: SpeakerCharacterType = Union[str, int] - from ..models import G2PModel -__all__ = ["Rewriter", "RewriterWorker", "PyniniDictionaryGenerator"] +__all__ = [ + "Rewriter", + "RewriterWorker", + "PyniniGenerator", + "PyniniCorpusGenerator", + "PyniniWordListGenerator", +] class Rewriter: @@ -68,7 +75,7 @@ class RewriterWorker(mp.Process): def __init__( self, job_q: mp.Queue, - return_dict: Dict[str, Union[str, Any]], + return_dict: dict[str, Union[str, Any]], rewriter: Rewriter, counter: Counter, stopped: Stopped, @@ -105,7 +112,7 @@ def run(self) -> None: return -def clean_up_word(word: str, graphemes: Set[str]) -> Tuple[str, List[str]]: +def clean_up_word(word: str, graphemes: set[str]) -> tuple[str, list[str]]: """ Clean up word by removing graphemes not in a specified set @@ -113,7 +120,7 @@ def clean_up_word(word: str, graphemes: Set[str]) -> Tuple[str, List[str]]: ---------- word : str Input string - graphemes: Set[str] + graphemes: set[str] Set of allowable graphemes Returns @@ -133,76 +140,90 @@ def clean_up_word(word: str, graphemes: Set[str]) -> Tuple[str, List[str]]: return "".join(new_word), missing_graphemes -class PyniniDictionaryGenerator: +class OrthographyGenerator(G2PTopLevelMixin): """ - Class for generating pronunciations from a G2P model + Abstract mixin class for generating "pronunciations" based off the orthographic word + + See Also + -------- + :class:`~montreal_forced_aligner.g2p.mixins.G2PTopLevelMixin` + For top level G2P generation parameters """ - def __init__( - self, - g2p_model: G2PModel, - word_set: Collection[str], - temp_directory: Optional[str] = None, - num_jobs: int = 3, - num_pronunciations: int = 1, - logger: Optional[logging.Logger] = None, - ): - super(PyniniDictionaryGenerator, self).__init__() - if not temp_directory: - temp_directory = TEMP_DIR - temp_directory = os.path.join(temp_directory, "G2P") - self.model = g2p_model - - 
self.temp_directory = os.path.join(temp_directory, self.model.name) - log_dir = os.path.join(self.temp_directory, "logging") - os.makedirs(log_dir, exist_ok=True) - self.log_file = os.path.join(log_dir, "g2p.log") - if logger is not None: - self.logger = logger - else: - self.logger = logging.getLogger("g2p") - self.logger.setLevel(logging.INFO) - handler = logging.FileHandler(self.log_file, "w", "utf-8") - handler.setFormatter = logging.Formatter("%(name)s %(message)s") - self.logger.addHandler(handler) - self.words = word_set - self.num_jobs = num_jobs - self.num_pronunciations = num_pronunciations - - def generate(self) -> Dict[str, List[str]]: + def generate_pronunciations(self) -> dict[str, list[str]]: + """ + Generate pronunciations for the word set + + Returns + ------- + dict[str, list[str]] + Mapping of words to their "pronunciation" + """ + pronunciations = {} + for word in self.words_to_g2p: + pronunciation = list(word) + pronunciations[word] = pronunciation + return pronunciations + + +class PyniniGenerator(G2PTopLevelMixin): + """ + Class for generating pronunciations from a Pynini G2P model + + Parameters + ---------- + g2p_model_path: str + Path to G2P model + + See Also + -------- + :class:`~montreal_forced_aligner.g2p.mixins.G2PTopLevelMixin` + For top level G2P generation parameters + + Attributes + ---------- + g2p_model: G2PModel + G2P model + """ + + def __init__(self, g2p_model_path: str, **kwargs): + self.g2p_model = G2PModel(g2p_model_path) + super().__init__(**kwargs) + + def generate_pronunciations(self) -> dict[str, list[str]]: """ Generate pronunciations Returns ------- - Dict + dict[str, list[str]] Mappings of keys to their generated pronunciations """ - if self.model.meta["architecture"] == "phonetisaurus": + if self.g2p_model.meta["architecture"] == "phonetisaurus": raise G2PError( "Previously trained Phonetisaurus models from 1.1 and earlier are not currently supported. 
" "Please retrain your model using 2.0+" ) input_token_type = "utf8" - fst = pynini.Fst.read(self.model.fst_path) + fst = pynini.Fst.read(self.g2p_model.fst_path) output_token_type = "utf8" - if self.model.sym_path is not None and os.path.exists(self.model.sym_path): - output_token_type = pynini.SymbolTable.read_text(self.model.sym_path) + if self.g2p_model.sym_path is not None and os.path.exists(self.g2p_model.sym_path): + output_token_type = pynini.SymbolTable.read_text(self.g2p_model.sym_path) rewriter = Rewriter(fst, input_token_type, output_token_type, self.num_pronunciations) ind = 0 - num_words = len(self.words) - words = list(self.words) + num_words = len(self.words_to_g2p) + words = list(self.words_to_g2p) begin = time.time() last_value = 0 missing_graphemes = set() - print("Generating pronunciations...") + self.log_info("Generating pronunciations...") to_return = {} if num_words < 30 or self.num_jobs < 2: for word in words: - w, m = clean_up_word(word, self.model.meta["graphemes"]) + w, m = clean_up_word(word, self.g2p_model.meta["graphemes"]) missing_graphemes.update(m) if not w: continue @@ -219,7 +240,7 @@ def generate(self) -> Dict[str, List[str]]: if ind == num_words: break try: - w, m = clean_up_word(words[ind], self.model.meta["graphemes"]) + w, m = clean_up_word(words[ind], self.g2p_model.meta["graphemes"]) missing_graphemes.update(m) if not w: ind += 1 @@ -239,7 +260,7 @@ def generate(self) -> Dict[str, List[str]]: while True: if ind == num_words: break - w, m = clean_up_word(words[ind], self.model.meta["graphemes"]) + w, m = clean_up_word(words[ind], self.g2p_model.meta["graphemes"]) missing_graphemes.update(m) if not w: ind += 1 @@ -254,32 +275,175 @@ def generate(self) -> Dict[str, List[str]]: p.join() if "MFA_EXCEPTION" in return_dict: element, exc = return_dict["MFA_EXCEPTION"] - print(element) + self.log_error(f"Encountered error processing: {element}") raise exc - for w in self.words: + for w in self.words_to_g2p: if w in return_dict: 
to_return[w] = return_dict[w] - self.logger.debug(f"Processed {num_words} in {time.time() - begin} seconds") + self.log_debug(f"Processed {num_words} in {time.time() - begin} seconds") return to_return - def output(self, outfile: str) -> None: - """ - Output pronunciations to text file - Parameters - ---------- - outfile: str - Path to save - """ - results = self.generate() - with open(outfile, "w", encoding="utf8") as f: - for (word, pronunciation) in results.items(): - if not pronunciation: - continue - if isinstance(pronunciation, list): - for p in pronunciation: - if not p: - continue - f.write(f"{word}\t{p}\n") - else: - f.write(f"{word}\t{pronunciation}\n") +class PyniniWordListGenerator(PyniniGenerator, TopLevelMfaWorker): + """ + Top-level worker for generating pronunciations from a word list and a Pynini G2P model + + Parameters + ---------- + word_list_path: str + Path to word list file + + See Also + -------- + :class:`~montreal_forced_aligner.g2p.generator.PyniniGenerator` + For Pynini G2P generation parameters + :class:`~montreal_forced_aligner.abc.TopLevelMfaWorker` + For top-level parameters + + Attributes + ---------- + word_list: list[str] + Word list to generate pronunciations + """ + + def __init__(self, word_list_path: str, **kwargs): + self.word_list_path = word_list_path + self.word_list = [] + super().__init__(**kwargs) + + @property + def data_directory(self) -> str: + """Data directory""" + return self.working_directory + + @property + def data_source_identifier(self) -> str: + """Name of the word list file""" + return os.path.splitext(os.path.basename(self.word_list_path))[0] + + def setup(self) -> None: + """Set up the G2P generator""" + if self.initialized: + return + with open(self.word_list_path, "r", encoding="utf8") as f: + for line in f: + self.word_list.extend(line.strip().split()) + if not self.include_bracketed: + self.word_list = [x for x in self.word_list if not self.check_bracketed(x)] + 
self.g2p_model.validate(self.words_to_g2p) + self.initialized = True + + @property + def words_to_g2p(self) -> list[str]: + """Words to produce pronunciations""" + return self.word_list + + +class PyniniCorpusGenerator(PyniniGenerator, TextCorpusMixin, TopLevelMfaWorker): + """ + Top-level worker for generating pronunciations from a corpus and a Pynini G2P model + + See Also + -------- + :class:`~montreal_forced_aligner.g2p.generator.PyniniGenerator` + For Pynini G2P generation parameters + :class:`~montreal_forced_aligner.corpus.text_corpus.TextCorpusMixin` + For corpus parsing parameters + :class:`~montreal_forced_aligner.abc.TopLevelMfaWorker` + For top-level parameters + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def setup(self) -> None: + """Set up the pronunciation generator""" + if self.initialized: + return + self._load_corpus() + self.g2p_model.validate(self.words_to_g2p) + self.initialized = True + + @property + def words_to_g2p(self) -> list[str]: + """Words to produce pronunciations""" + word_list = self.corpus_word_set + if not self.include_bracketed: + word_list = [x for x in word_list if not self.check_bracketed(x)] + return word_list + + +class OrthographicCorpusGenerator(OrthographyGenerator, TextCorpusMixin, TopLevelMfaWorker): + """ + Top-level class for generating "pronunciations" from the orthography of a corpus + + See Also + -------- + :class:`~montreal_forced_aligner.g2p.generator.OrthographyGenerator` + For orthography-based G2P generation parameters + :class:`~montreal_forced_aligner.corpus.text_corpus.TextCorpusMixin` + For corpus parsing parameters + :class:`~montreal_forced_aligner.abc.TopLevelMfaWorker` + For top-level parameters + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def setup(self) -> None: + """Set up the pronunciation generator""" + if self.initialized: + return + self._load_corpus() + self.initialized = True + + @property + def words_to_g2p(self) -> list[str]: + 
"""Words to produce pronunciations""" + word_list = self.corpus_word_set + if not self.include_bracketed: + word_list = [x for x in word_list if not self.check_bracketed(x)] + return word_list + + +class OrthographicWordListGenerator(OrthographyGenerator, TopLevelMfaWorker): + """ + Top-level class for generating "pronunciations" from the orthography of a corpus + + Parameters + ---------- + word_list_path: str + Path to word list file + See Also + -------- + :class:`~montreal_forced_aligner.g2p.generator.OrthographyGenerator` + For orthography-based G2P generation parameters + :class:`~montreal_forced_aligner.abc.TopLevelMfaWorker` + For top-level parameters + + Attributes + ---------- + word_list: list[str] + Word list to generate pronunciations + """ + + def __init__(self, word_list_path: str, **kwargs): + super().__init__(**kwargs) + self.word_list_path = word_list_path + self.word_list = [] + + def setup(self) -> None: + """Set up the pronunciation generator""" + if self.initialized: + return + with open(self.word_list_path, "r", encoding="utf8") as f: + for line in f: + self.word_list.extend(line.strip().split()) + if not self.include_bracketed: + self.word_list = [x for x in self.word_list if not self.check_bracketed(x)] + self.initialized = True + + @property + def words_to_g2p(self) -> list[str]: + """Words to produce pronunciations""" + return self.word_list diff --git a/montreal_forced_aligner/g2p/mixins.py b/montreal_forced_aligner/g2p/mixins.py new file mode 100644 index 00000000..fa6212c3 --- /dev/null +++ b/montreal_forced_aligner/g2p/mixins.py @@ -0,0 +1,96 @@ +from abc import ABCMeta, abstractmethod + +from montreal_forced_aligner.abc import MfaWorker +from montreal_forced_aligner.dictionary.mixins import DictionaryMixin + + +class G2PMixin(metaclass=ABCMeta): + """ + Abstract mixin class for G2P functionality + + Parameters + ---------- + include_bracketed: bool + Flag for whether to generate pronunciations for fully bracketed words, defaults to 
False + num_pronunciations: int + Number of pronunciations to generate, defaults to 1 + """ + + def __init__(self, include_bracketed: bool = False, num_pronunciations: int = 1, **kwargs): + super().__init__(**kwargs) + self.num_pronunciations = num_pronunciations + self.include_bracketed = include_bracketed + + @abstractmethod + def generate_pronunciations(self) -> dict[str, list[str]]: + """ + Generate pronunciations + + Returns + ------- + dict[str, list[str]] + Mappings of keys to their generated pronunciations + """ + ... + + @property + @abstractmethod + def words_to_g2p(self): + """Words to produce pronunciations""" + ... + + +class G2PTopLevelMixin(MfaWorker, DictionaryMixin, G2PMixin): + """ + Abstract mixin class for top-level G2P functionality + + See Also + -------- + :class:`~montreal_forced_aligner.abc.MfaWorker` + For base MFA parameters + :class:`~montreal_forced_aligner.dictionary.mixins.DictionaryMixin` + For dictionary parsing parameters + :class:`~montreal_forced_aligner.g2p.mixins.G2PMixin` + For base G2P parameters + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def generate_pronunciations(self) -> dict[str, list[str]]: + """ + Generate pronunciations + + Returns + ------- + dict[str, list[str]] + Mappings of keys to their generated pronunciations + """ + raise NotImplementedError + + @property + def workflow_identifier(self) -> str: + """G2P identifier""" + return "g2p" + + def export_pronunciations(self, output_file_path: str) -> None: + """ + Output pronunciations to text file + + Parameters + ---------- + output_file_path: str + Path to save + """ + results = self.generate_pronunciations() + with open(output_file_path, "w", encoding="utf8") as f: + for (word, pronunciation) in results.items(): + if not pronunciation: + continue + if isinstance(pronunciation, list): + for p in pronunciation: + if not p: + continue + f.write(f"{word}\t{p}\n") + else: + f.write(f"{word}\t{pronunciation}\n") diff --git 
a/montreal_forced_aligner/g2p/trainer.py b/montreal_forced_aligner/g2p/trainer.py index e38dd0d0..7ab0317f 100644 --- a/montreal_forced_aligner/g2p/trainer.py +++ b/montreal_forced_aligner/g2p/trainer.py @@ -14,15 +14,16 @@ import sys import time import traceback -from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple +from typing import Any, Callable, NamedTuple, Optional import tqdm -from ..config import TEMP_DIR -from ..helper import score -from ..models import G2PModel -from ..multiprocessing import Counter, Stopped -from .generator import PyniniDictionaryGenerator +from montreal_forced_aligner.abc import MetaDict, MfaWorker, TopLevelMfaWorker, TrainerMixin +from montreal_forced_aligner.dictionary.pronunciation import PronunciationDictionaryMixin +from montreal_forced_aligner.g2p.generator import PyniniGenerator +from montreal_forced_aligner.helper import score +from montreal_forced_aligner.models import G2PModel +from montreal_forced_aligner.utils import Counter, Stopped try: import pynini @@ -35,27 +36,23 @@ except ImportError: pynini = None pywrapfst = None - TokenType = None + TokenType = Optional[str] Fst = None def convert(x): + """stub function""" pass G2P_DISABLED = True -if TYPE_CHECKING: - from ..abc import Dictionary, DictionaryEntryType - from ..config.train_g2p_config import TrainG2PConfig - -Labels = List[Any] +Labels = list[Any] TOKEN_TYPES = ["byte", "utf8"] -DEV_NULL = open(os.devnull, "w") INF = float("inf") RAND_MAX = 32767 -__all__ = ["RandomStartWorker", "PairNGramAligner", "PyniniTrainer"] +__all__ = ["RandomStartWorker", "PairNGramAligner", "PyniniTrainer", "G2PTrainer"] class RandomStart(NamedTuple): @@ -67,63 +64,7 @@ class RandomStart(NamedTuple): p_path: str c_path: str tempdir: str - train_opts: List[str] - - -def compute_validation_errors( - gold_values: Dict[str, List[Dict[str, str]]], - hypothesis_values: Dict[str, List[str]], - num_jobs: int = 3, -) -> Tuple[float, float]: - """ - Computes 
validation errors - - Parameters - ---------- - gold_values: Dict - Gold labels - hypothesis_values: Dict - Hypothesis labels - num_jobs: int - Number of jobs to use - - Returns - ------- - float - Word error rate - float - Phone error rate - """ - # Word-level measures. - correct = 0 - incorrect = 0 - # Label-level measures. - total_edits = 0 - total_length = 0 - # Since the edit distance algorithm is quadratic, let's do this with - # multiprocessing. - with mp.Pool(num_jobs) as pool: - to_comp = [] - for word, hyp in hypothesis_values.items(): - g = gold_values[word][0]["pronunciation"] - hyp = [h.split(" ") for h in hyp] - to_comp.append((g, hyp)) - gen = pool.starmap(score, to_comp) - for (edits, length) in gen: - if edits == 0: - correct += 1 - else: - incorrect += 1 - total_edits += edits - total_length += length - for w, gold in gold_values.items(): - if w not in hypothesis_values: - incorrect += 1 - gold = gold[0]["pronunciation"] - total_edits += len(gold) - total_length += len(gold) - - return 100 * incorrect / (correct + incorrect), 100 * total_edits / total_length + train_opts: list[str] class RandomStartWorker(mp.Process): @@ -134,7 +75,7 @@ class RandomStartWorker(mp.Process): def __init__( self, job_q: mp.Queue, - return_dict: Dict, + return_dict: dict, function: Callable, counter: Counter, stopped: Stopped, @@ -161,7 +102,7 @@ def run(self) -> None: self.return_dict[fst_path] = likelihood except Exception: self.stopped.stop() - self.return_dict["error"] = args, Exception( + self.return_dict["MFA_ERROR"] = args, Exception( traceback.format_exception(*sys.exc_info()) ) self.counter.increment() @@ -173,14 +114,14 @@ class PairNGramAligner: _compactor = functools.partial(convert, fst_type="compact_string") - def __init__(self, temp_directory: str): - self.tempdir = temp_directory - self.g_path = os.path.join(self.tempdir, "g.far") - self.p_path = os.path.join(self.tempdir, "p.far") - self.c_path = os.path.join(self.tempdir, "c.fst") - self.align_path = 
os.path.join(self.tempdir, "align.fst") - self.afst_path = os.path.join(self.tempdir, "afst.far") - self.align_log_path = os.path.join(self.tempdir, "align.log") + def __init__(self, working_directory: str): + self.working_directory = working_directory + self.g_path = os.path.join(self.working_directory, "g.far") + self.p_path = os.path.join(self.working_directory, "p.far") + self.c_path = os.path.join(self.working_directory, "c.fst") + self.align_path = os.path.join(self.working_directory, "align.fst") + self.afst_path = os.path.join(self.working_directory, "afst.far") + self.align_log_path = os.path.join(self.working_directory, "align.log") self.logger = logging.getLogger("g2p_aligner") self.logger.setLevel(logging.DEBUG) @@ -236,7 +177,7 @@ def align( self.logger.info("Success! FAR path: %s; encoder path: %s", far_path, encoder_path) @staticmethod - def _label_union(labels: Set[int], epsilon: bool) -> Fst: + def _label_union(labels: set[int], epsilon: bool) -> Fst: """Creates FSA over a union of the labels.""" side = pynini.Fst() src = side.add_state() @@ -268,8 +209,8 @@ def _lexicon_covering( ) -> None: """Builds covering grammar and lexicon FARs.""" # Sets of labels for the covering grammar. 
- g_labels: Set[int] = set() - p_labels: Set[int] = set() + g_labels: set[int] = set() + p_labels: set[int] = set() self.logger.info("Constructing grapheme and phoneme FARs") g_writer = pywrapfst.FarWriter.create(self.g_path) p_writer = pywrapfst.FarWriter.create(self.p_path) @@ -300,7 +241,7 @@ def _lexicon_covering( covering.write(self.c_path) @staticmethod - def _random_start(random_start: RandomStart) -> Tuple[str, float]: + def _random_start(random_start: RandomStart) -> tuple[str, float]: """Performs a single random start.""" start = time.time() logger = logging.getLogger("g2p_aligner") @@ -384,7 +325,7 @@ def _alignments( self.g_path, self.p_path, self.c_path, - self.tempdir, + self.working_directory, train_opts, ) ) @@ -397,8 +338,7 @@ def _alignments( job_queue = mp.JoinableQueue(cores + 2) # Actually runs starts. - self.logger.info("Random starts") - print("Calculating alignments...") + self.logger.info("Calculating alignments...") begin = time.time() last_value = 0 ind = 0 @@ -440,8 +380,8 @@ def _alignments( job_queue.join() for p in procs: p.join() - if "error" in return_dict: - element, exc = return_dict["error"] + if "MFA_ERROR" in return_dict: + element, exc = return_dict["MFA_ERROR"] print(element) raise exc (best_fst, best_likelihood) = min(return_dict.items(), key=operator.itemgetter(1)) @@ -484,93 +424,253 @@ def _encode(self, far_path: str, encoder_path: str) -> None: encoder.write(encoder_path) -class PyniniTrainer: +class PyniniValidator(PyniniGenerator): + """ + Class for running validation for G2P model training + + Parameters + ---------- + word_list: list[str] + List of words to generate pronunciations + + See Also + -------- + :class:`~montreal_forced_aligner.g2p.generator.PyniniGenerator` + For parameters to generate pronunciations """ - Class for G2P trainer that uses Pynini functionality + + def __init__(self, word_list: list[str], **kwargs): + super().__init__(**kwargs) + self.word_list = word_list + + @property + def 
words_to_g2p(self) -> list[str]: + """Words to produce pronunciations""" + return self.word_list + + +class G2PTrainer(MfaWorker, TrainerMixin, PronunciationDictionaryMixin): + """ + Abstract mixin class for G2P training Parameters ---------- - dictionary: :class:`~montreal_forced_aligner.dictionary.PronunciationDictionary` - PronunciationDictionary to train from` - model_path: str - Output model path - train_config: TrainG2PConfig - Configuration for training G2P model - temp_directory: str, optional - Temporary directory, defaults to MFA's temporary directory + validation_proportion: float + Proportion of words to use as the validation set, defaults to 0.1, only used if ``evaluate`` is True + num_pronunciations: int + Number of pronunciations to generate + evaluate: bool + Flag for whether to evaluate the model performance on an validation set + + See Also + -------- + :class:`~montreal_forced_aligner.abc.MfaWorker` + For base MFA parameters + :class:`~montreal_forced_aligner.abc.TrainerMixin` + For base trainer parameters + :class:`~montreal_forced_aligner.dictionary.pronunciation.PronunciationDictionaryMixin` + For pronunciation dictionary parameters + + Attributes + ---------- + g2p_training_dictionary: dict[str, list[str]] + Dictionary of words to pronunciations to train from + g2p_validation_dictionary: dict[str, list[str]] + Dictionary of words to pronunciations to validate performance against + g2p_graphemes: set[str] + Set of graphemes in the training set + """ + + def __init__( + self, + validation_proportion: float = 0.1, + num_pronunciations: int = 1, + evaluate: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + self.evaluate = evaluate + self.validation_proportion = validation_proportion + self.num_pronunciations = num_pronunciations + self.g2p_training_dictionary = {} + self.g2p_validation_dictionary = None + self.g2p_graphemes = set() + + +class PyniniTrainer(G2PTrainer, TopLevelMfaWorker): + """ + Top-level G2P trainer that uses Pynini 
functionality + + Parameters + ---------- + order: int + Order of the ngram model, defaults to 7 + random_starts: int + Number of random starts to use in initialization, defaults to 25 + seed: int + Seed for randomization, defaults to 1917 + delta: float + Comparison/quantization delta for Baum-Welch training, defaults to 1/1024 + lr: float + Learning rate for Baum-Welch training, defaults to 1.0 + batch_size:int + Batch size for Baum-Welch training, defaults to 200 + num_iterations:int + Maximum number of iterations to use in Baum-Welch training, defaults to 10 + smoothing_method:str + Smoothing method for the ngram model, defaults to "kneser_ney" + pruning_method:str + Pruning method for pruning the ngram model, defaults to "relative_entropy" + model_size: int + Target number of ngrams for pruning, defaults to 1000000 input_epsilon: bool Flag for whether to allow for epsilon on input strings, default True output_epsilon: bool Flag for whether to allow for epsilon on output strings, default True - num_jobs: int - Number processes to use - verbose: bool - Flag for provide debug output to the terminal + fst_default_cache_gc: str + String to pass to OpenFst binaries for GC behavior + fst_default_cache_gc_limit: str + String to pass to OpenFst binaries for GC behavior + + See Also + -------- + :class:`~montreal_forced_aligner.g2p.trainer.G2PTrainer` + For base G2P training parameters + :class:`~montreal_forced_aligner.abc.TopLevelMfaWorker` + For top-level parameters """ def __init__( self, - dictionary: Dictionary, - model_path: str, - train_config: TrainG2PConfig, - temp_directory: Optional[str] = None, + order: int = 7, + random_starts: int = 25, + seed: int = 1917, + delta: float = 1 / 1024, + lr: float = 1.0, + batch_size: int = 200, + num_iterations: int = 10, + smoothing_method: str = "kneser_ney", + pruning_method: str = "relative_entropy", + model_size: int = 1000000, input_epsilon: bool = True, output_epsilon: bool = True, - num_jobs: int = 3, - verbose: 
bool = False, + fst_default_cache_gc="", + fst_default_cache_gc_limit="", + **kwargs, ): - super(PyniniTrainer, self).__init__() - if not temp_directory: - temp_directory = TEMP_DIR - self.temp_directory = os.path.join(temp_directory, "G2P") - self.train_config = train_config - self.verbose = verbose - self.models_temp_dir = os.path.join(temp_directory, "models", "G2P") - - self.name, _ = os.path.splitext(os.path.basename(model_path)) - self.temp_directory = os.path.join(self.temp_directory, self.name) - os.makedirs(self.temp_directory, exist_ok=True) - os.makedirs(self.models_temp_dir, exist_ok=True) - self.model_path = model_path - self.fst_path = os.path.join(self.temp_directory, "model.fst") - self.far_path = os.path.join(self.temp_directory, self.name + ".far") - self.encoder_path = os.path.join(self.temp_directory, self.name + ".enc") - self.dictionary = dictionary + super().__init__(**kwargs) + self.order = order + self.random_starts = random_starts + self.seed = seed + self.delta = delta + self.lr = lr + self.batch_size = batch_size + self.num_iterations = num_iterations + self.smoothing_method = smoothing_method + self.pruning_method = pruning_method + self.model_size = model_size self.input_epsilon = input_epsilon self.output_epsilon = output_epsilon - self.num_jobs = num_jobs - if not self.train_config.use_mp: - self.num_jobs = 1 - self.fst_default_cache_gc = "" - self.fst_default_cache_gc_limit = "" - self.train_log_path = os.path.join(self.temp_directory, "train.log") - - self.logger = logging.getLogger("g2p_trainer") - self.logger.setLevel(logging.DEBUG) + self.fst_default_cache_gc = fst_default_cache_gc + self.fst_default_cache_gc_limit = fst_default_cache_gc_limit + + @property + def data_source_identifier(self) -> str: + """Dictionary name""" + return self.dictionary_model.name + + @property + def data_directory(self) -> str: + """Data directory for trainer""" + return self.working_directory + + @property + def workflow_identifier(self) -> str: + 
"""Identifier for Pynini G2P trainer""" + return "pynini_train_g2p" + + @property + def configuration(self) -> MetaDict: + """Configuration for G2P trainer""" + config = super().configuration + config.update({"dictionary_path": self.dictionary_model.path}) + return config + + def train_iteration(self) -> None: + """Train iteration, not used""" + pass - handler = logging.FileHandler(self.train_log_path) - handler.setLevel(logging.DEBUG) - formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") - handler.setFormatter(formatter) - self.logger.addHandler(handler) - handler = logging.StreamHandler(sys.stdout) - if self.verbose: - handler.setLevel(logging.DEBUG) - else: - handler.setLevel(logging.INFO) - formatter = logging.Formatter("%(levelname)s - %(message)s") - handler.setFormatter(formatter) - self.logger.addHandler(handler) - self.model_log_path = os.path.join(self.temp_directory, "model.log") - self.sym_path = os.path.join(self.temp_directory, "phones.sym") - self.output_token_type = None + def setup(self) -> None: + """Setup for G2P training""" + if self.initialized: + return + os.makedirs(self.working_log_directory, exist_ok=True) + self.g2p_training_dictionary = self.words + self.initialize_training() + self.initialized = True + + @property + def architecture(self) -> str: + """Pynini""" + return "pynini" + + @property + def meta(self) -> MetaDict: + """Metadata for exported G2P model""" + from datetime import datetime + + from ..utils import get_mfa_version + + return { + "version": get_mfa_version(), + "architecture": self.architecture, + "train_date": str(datetime.now()), + "phones": sorted(self.non_silence_phones), + "graphemes": self.graphemes, + } + + @property + def input_path(self) -> str: + """Path to temporary file to store training data""" + return os.path.join(self.working_directory, "input.txt") + + def initialize_training(self) -> None: + """Initialize training G2P model""" + if self.evaluate: + word_dict = 
self.g2p_training_dictionary + words = sorted(word_dict.keys()) + total_items = len(words) + validation_items = int(total_items * self.validation_proportion) + validation_words = random.sample(words, validation_items) + self.g2p_training_dictionary = { + k: v for k, v in word_dict.items() if k not in validation_words + } + self.g2p_validation_dictionary = { + k: v for k, v in word_dict.items() if k in validation_words + } + for k in self.g2p_training_dictionary.keys(): + self.g2p_graphemes.update(k) + phones_path = os.path.join(self.working_directory, "phones_only.txt") + + with open(self.input_path, "w", encoding="utf8") as f2, open( + phones_path, "w", encoding="utf8" + ) as phonef: + for word, v in self.g2p_training_dictionary.items(): + if re.match(r"\W", word) is not None: + continue + for v2 in v: + f2.write(f"{word}\t{' '.join(v2['pronunciation'])}\n") + for v2 in v: + phonef.write(f"{' '.join(v2['pronunciation'])}\n") + subprocess.call(["ngramsymbols", phones_path, self.sym_path]) + os.remove(phones_path) def clean_up(self) -> None: """ Clean up temporary files """ - for name in os.listdir(self.temp_directory): - path = os.path.join(self.temp_directory, name) + for name in os.listdir(self.working_directory): + path = os.path.join(self.working_directory, name) if os.path.isdir(path): shutil.rmtree(path, ignore_errors=True) elif not name.endswith(".log"): @@ -581,15 +681,17 @@ def generate_model(self) -> None: Generate an ngram G2P model from FAR strings """ assert os.path.exists(self.far_path) - with open(self.model_log_path, "w", encoding="utf8") as logf: - ngram_count_path = os.path.join(self.temp_directory, "ngram.count") - ngram_make_path = os.path.join(self.temp_directory, "ngram.make") - ngram_shrink_path = os.path.join(self.temp_directory, "ngram.shrink") + with open( + os.path.join(self.working_log_directory, "model.log"), "w", encoding="utf8" + ) as logf: + ngram_count_path = os.path.join(self.working_directory, "ngram.count") + ngram_make_path = 
os.path.join(self.working_directory, "ngram.make") + ngram_shrink_path = os.path.join(self.working_directory, "ngram.shrink") ngramcount_proc = subprocess.Popen( [ "ngramcount", "--require_symbols=false", - "--order={}".format(self.train_config.order), + f"--order={self.order}", self.far_path, ngram_count_path, ], @@ -600,7 +702,7 @@ def generate_model(self) -> None: ngrammake_proc = subprocess.Popen( [ "ngrammake", - "--method=" + self.train_config.smoothing_method, + f"--method={self.smoothing_method}", ngram_count_path, ngram_make_path, ], @@ -611,8 +713,8 @@ def generate_model(self) -> None: ngramshrink_proc = subprocess.Popen( [ "ngramshrink", - "--method=" + self.train_config.pruning_method, - "--target_number_of_ngrams={}".format(self.train_config.model_size), + f"--method={self.pruning_method}", + f"--target_number_of_ngrams={self.model_size}", ngram_make_path, ngram_shrink_path, ], @@ -630,110 +732,153 @@ def generate_model(self) -> None: os.remove(ngram_make_path) os.remove(ngram_shrink_path) - directory, filename = os.path.split(self.model_path) + def export_model(self, output_model_path: str) -> None: + """ + Export G2P model to specified path + + Parameters + ---------- + output_model_path:str + Path to export model + """ + directory, filename = os.path.split(output_model_path) basename, _ = os.path.splitext(filename) - model = G2PModel.empty(basename, root_directory=self.models_temp_dir) - model.add_meta_file(self.dictionary) - model.add_fst_model(self.temp_directory) - model.add_sym_path(self.temp_directory) + models_temp_dir = os.path.join(self.working_directory, "model_archive_tempo") + model = G2PModel.empty(basename, root_directory=models_temp_dir) + model.add_meta_file(self) + model.add_fst_model(self.working_directory) + model.add_sym_path(self.working_directory) if directory: os.makedirs(directory, exist_ok=True) - basename, _ = os.path.splitext(self.model_path) + basename, _ = os.path.splitext(output_model_path) model.dump(basename) 
model.clean_up() self.clean_up() - self.logger.info(f"Saved model to {self.model_path}") + self.logger.info(f"Saved model to {output_model_path}") + + @property + def fst_path(self): + """Internal temporary FST file""" + return os.path.join(self.working_directory, "model.fst") + + @property + def far_path(self): + """Internal temporary FAR file""" + return os.path.join(self.working_directory, f"{self.data_source_identifier}.far") - def train(self, word_dict: Optional[Dict[str, DictionaryEntryType]] = None) -> None: + @property + def encoder_path(self): + """Internal temporary encoder file""" + return os.path.join(self.working_directory, f"{self.data_source_identifier}.enc") + + @property + def sym_path(self): + """Internal temporary symbol file""" + return os.path.join(self.working_directory, "phones.sym") + + def train(self) -> None: """ Train a G2P model - - Parameters - ---------- - word_dict: Dict[str, DictionaryEntryType] - PronunciationDictionary of words to pronunciations, optional, defaults to the dictionary's - set of words """ - input_path = os.path.join(self.temp_directory, "input.txt") - phones_path = os.path.join(self.temp_directory, "phones_only.txt") - if word_dict is None: - word_dict = self.dictionary.actual_words - with open(input_path, "w", encoding="utf8") as f2, open( - phones_path, "w", encoding="utf8" - ) as phonef: - for word, v in word_dict.items(): - if re.match(r"\W", word) is not None: - continue - for v2 in v: - f2.write(f"{word}\t{' '.join(v2['pronunciation'])}\n") - for v2 in v: - phonef.write(f"{' '.join(v2['pronunciation'])}\n") - subprocess.call(["ngramsymbols", phones_path, self.sym_path]) - os.remove(phones_path) - aligner = PairNGramAligner(self.temp_directory) + aligner = PairNGramAligner(self.working_directory) input_token_type = "utf8" - self.output_token_type = pynini.SymbolTable.read_text(self.sym_path) + output_token_type = pynini.SymbolTable.read_text(self.sym_path) begin = time.time() if not os.path.exists(self.far_path) 
or not os.path.exists(self.encoder_path): aligner.align( - input_path, + self.input_path, self.far_path, self.encoder_path, input_token_type, self.input_epsilon, - self.output_token_type, + output_token_type, self.output_epsilon, self.num_jobs, - self.train_config.random_starts, - self.train_config.seed, - self.train_config.batch_size, - self.train_config.delta, - self.train_config.lr, - self.train_config.max_iterations, + self.random_starts, + self.seed, + self.batch_size, + self.delta, + self.lr, + self.num_iterations, self.fst_default_cache_gc, self.fst_default_cache_gc_limit, ) - self.logger.debug(f"Aligning {len(word_dict)} words took {time.time() - begin} seconds") + self.logger.debug( + f"Aligning {len(self.g2p_training_dictionary)} words took {time.time() - begin} seconds" + ) begin = time.time() self.generate_model() self.logger.debug( - f"Generating model for {len(word_dict)} words took {time.time() - begin} seconds" + f"Generating model for {len(self.g2p_training_dictionary)} words took {time.time() - begin} seconds" ) - def validate(self) -> None: + def finalize_training(self) -> None: + """Finalize training""" + if self.evaluate: + self.evaluate_g2p_model() + + def evaluate_g2p_model(self) -> None: """ Validate the G2P model against held out data """ + temp_model_path = os.path.join(self.working_log_directory, "g2p_model.zip") + self.export_model(temp_model_path) - word_dict = self.dictionary.actual_words - validation = 0.1 - words = word_dict.keys() - total_items = len(words) - validation_items = int(total_items * validation) - validation_words = random.sample(words, validation_items) - training_dictionary = {k: v for k, v in word_dict.items() if k not in validation_words} - validation_dictionary = {k: v for k, v in word_dict.items() if k in validation_words} - train_graphemes = set() - for k in word_dict.keys(): - train_graphemes.update(k) - self.train(training_dictionary) - - model = G2PModel(self.model_path, root_directory=self.temp_directory) - - 
gen = PyniniDictionaryGenerator( - model, - validation_dictionary.keys(), - temp_directory=os.path.join(self.temp_directory, "validation"), + gen = PyniniValidator( + g2p_model_path=temp_model_path, + word_list=list(self.g2p_validation_dictionary.keys()), + temporary_directory=os.path.join(self.working_directory, "validation"), num_jobs=self.num_jobs, - num_pronunciations=self.train_config.num_pronunciations, + num_pronunciations=self.num_pronunciations, ) - output = gen.generate() + output = gen.generate_pronunciations() + self.compute_validation_errors(output) + + def compute_validation_errors( + self, + hypothesis_values: dict[str, list[str]], + ): + """ + Computes validation errors + + Parameters + ---------- + hypothesis_values: dict[str, list[str]] + Hypothesis labels + """ begin = time.time() - wer, ler = compute_validation_errors(validation_dictionary, output, num_jobs=self.num_jobs) - print(f"WER:\t{wer:.2f}") - print(f"LER:\t{ler:.2f}") + # Word-level measures. + correct = 0 + incorrect = 0 + # Label-level measures. + total_edits = 0 + total_length = 0 + # Since the edit distance algorithm is quadratic, let's do this with + # multiprocessing. 
+ with mp.Pool(self.num_jobs) as pool: + to_comp = [] + for word, hyp in hypothesis_values.items(): + g = self.g2p_validation_dictionary[word][0]["pronunciation"] + hyp = [h.split(" ") for h in hyp] + to_comp.append((g, hyp, True)) # Multiple hypotheses to compare + gen = pool.starmap(score, to_comp) + for (edits, length) in gen: + if edits == 0: + correct += 1 + else: + incorrect += 1 + total_edits += edits + total_length += length + for w, gold in self.g2p_validation_dictionary.items(): + if w not in hypothesis_values: + incorrect += 1 + gold = gold[0]["pronunciation"] + total_edits += len(gold) + total_length += len(gold) + wer = 100 * incorrect / (correct + incorrect) + ler = 100 * total_edits / total_length self.logger.info(f"WER:\t{wer:.2f}") self.logger.info(f"LER:\t{ler:.2f}") self.logger.debug( - f"Computation of errors for {len(validation_dictionary)} words took {time.time() - begin} seconds" + f"Computation of errors for {len(self.g2p_validation_dictionary)} words took {time.time() - begin} seconds" ) diff --git a/montreal_forced_aligner/helper.py b/montreal_forced_aligner/helper.py index f2563242..14aab486 100644 --- a/montreal_forced_aligner/helper.py +++ b/montreal_forced_aligner/helper.py @@ -5,15 +5,17 @@ """ from __future__ import annotations +import functools import sys import textwrap -from typing import TYPE_CHECKING, Any, Collection, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Any, Optional, Type import numpy from colorama import Fore, Style if TYPE_CHECKING: - from .abc import CorpusMappingType, Labels, MetaDict, ScpType + from montreal_forced_aligner.abc import CorpusMappingType, Labels, MetaDict, ScpType + from montreal_forced_aligner.textgrid import CtmInterval __all__ = [ @@ -27,16 +29,54 @@ "score", "edit_distance", "output_mapping", + "parse_old_features", + "compare_labels", + "overlap_scoring", + "align_phones", ] +def parse_old_features(config: MetaDict) -> MetaDict: + """ + Backwards compatibility function to parse 
old feature configuration blocks + + Parameters + ---------- + config: dict[str, Any] + Configuration parameters + + Returns + ------- + dict[str, Any] + Up to date versions of feature blocks + """ + feature_key_remapping = { + "type": "feature_type", + "deltas": "uses_deltas", + "lda": "uses_splices", + "fmllr": "uses_speaker_adaptation", + } + if "features" in config: + + for key, new_key in feature_key_remapping.items(): + if key in config["features"]: + config["features"][new_key] = config["features"][key] + del config["features"][key] + else: + for key, new_key in feature_key_remapping.items(): + if key in config: + config[new_key] = config[key] + del config[key] + return config + + class TerminalPrinter: """ Helper class to output colorized text Attributes ---------- - colors: Dict[str,str] + colors: dict[str, str] Mapping of color names to terminal codes in colorama (or empty strings if the global terminal_colors flag is set to False) """ @@ -65,13 +105,13 @@ def __init__(self): self.colors["reset"] = Style.RESET_ALL self.colors["normal"] = Style.NORMAL - def colorize(self, text: str, color: str) -> str: + def colorize(self, text: Any, color: str) -> str: """ Colorize a string Parameters ---------- - text: str + text: Any Text to colorize color: str Colorama code or empty string to wrap the text @@ -116,7 +156,7 @@ def print_config(self, configuration: MetaDict) -> None: Parameters ---------- - configuration: :class:`~montreal_forced_aligner.abc.MetaDict` + configuration: dict[str, Any] Configuration to print """ for k, v in configuration.items(): @@ -165,7 +205,7 @@ def print_information_line( if isinstance(value, (list, tuple, set)): value = comma_join([self.colorize(x, value_color) for x in sorted(value)]) else: - value = self.colorize(value, value_color) + value = self.colorize(str(value), value_color) indent = (" " * level) + "-" subsequent_indent = " " * (level + 1) if key: @@ -177,13 +217,14 @@ def print_information_line( 
print(wrapper.fill(f"{self.colorize(key, key_color)} {value}")) -def comma_join(sequence: Collection[Any]) -> str: +def comma_join(sequence: list[Any]) -> str: """ - Helper function to combine a list into a human-readable expression with commas and a final "and" separator + Helper function to combine a list into a human-readable expression with commas and a + final "and" separator Parameters ---------- - sequence: Collection[Any] + sequence: list[Any] Items to join together into a list Returns @@ -255,6 +296,14 @@ def output_mapping(mapping: CorpusMappingType, path: str, skip_safe: bool = Fals """ Helper function to save mapping information (i.e., utt2spk) in Kaldi scp format + CorpusMappingType is either a dictionary of key to value for + one-to-one mapping case and a dictionary of key to list of values for one-to-many case. + + See Also + -------- + :func:`~montreal_forced_aligner.helper.save_scp` + For another function that saves SCPs from lists + Parameters ---------- mapping: CorpusMappingType @@ -280,7 +329,15 @@ def save_scp( scp: ScpType, path: str, sort: Optional[bool] = True, multiline: Optional[bool] = False ) -> None: """ - Helper function to save an arbitrary SCP + Helper function to save an arbitrary SCP. + + ScpType is either a list of tuples (str, str) for one-to-one mapping files or + a list of tuples (str, list) for one-to-many mappings. + + See Also + -------- + :kaldi_docs:`io#io_sec_scp_details` + For more information on the SCP format Parameters ---------- @@ -313,7 +370,16 @@ def load_scp(path: str, data_type: Optional[Type] = str) -> CorpusMappingType: """ Load a Kaldi script file (.scp) - See http://kaldi-asr.org/doc/io.html#io_sec_scp_details for more information + Scp files in Kaldi can either be one-to-one or one-to-many, with the first element separated by + whitespace as the key and the remaining whitespace-delimited elements the values. 
+ + Returns a dictionary of key to value for + one-to-one mapping case and a dictionary of key to list of values for one-to-many case. + + See Also + -------- + :kaldi_docs:`io#io_sec_scp_details` + For more information on the SCP format Parameters ---------- @@ -324,9 +390,9 @@ def load_scp(path: str, data_type: Optional[Type] = str) -> CorpusMappingType: Returns ------- - dict - PronunciationDictionary where the keys are the first couple and the values are all - other columns in the script file + CorpusMappingType + Dictionary where the keys are the first column and the values are all + other columns in the scp file """ scp = {} @@ -351,15 +417,16 @@ def edit_distance(x: Labels, y: Labels) -> int: """ Compute edit distance between two sets of labels - For a more expressive version of the same, see: - - https://gist.github.com/kylebgorman/8034009 + See Also + -------- + `https://gist.github.com/kylebgorman/8034009 `_ + For a more expressive version of this function Parameters ---------- x: Labels First sequence to compare - y: Lables + y: Labels Second sequence to compare Returns @@ -384,7 +451,7 @@ def edit_distance(x: Labels, y: Labels) -> int: return int(table[-1][-1]) -def score(gold: Labels, hypo: (Labels, List)) -> Tuple[int, int]: +def score(gold: Labels, hypo: Labels, multiple_hypotheses=False) -> tuple[int, int]: """ Computes sufficient statistics for LER calculation. 
@@ -394,6 +461,8 @@ def score(gold: Labels, hypo: (Labels, List)) -> Tuple[int, int]: The reference labels hypo: Labels The hypothesized labels + multiple_hypotheses: bool + Flag for whether the hypotheses contain multiple Returns ------- @@ -402,7 +471,7 @@ def score(gold: Labels, hypo: (Labels, List)) -> Tuple[int, int]: int Length of the gold labels """ - if isinstance(hypo, list): + if multiple_hypotheses: edits = 100000 for h in hypo: e = edit_distance(gold, h) @@ -413,3 +482,133 @@ def score(gold: Labels, hypo: (Labels, List)) -> Tuple[int, int]: else: edits = edit_distance(gold, hypo) return edits, len(gold) + + +def compare_labels(ref: str, test: str, mapping: Optional[dict[str, str]] = None) -> int: + """ + + Parameters + ---------- + ref: str + test: str + mapping: Optional[dict[str, str]] + + Returns + ------- + int + 0 if labels match or they're in mapping, 2 otherwise + """ + if ref == test: + return 0 + if mapping is not None and test in mapping and mapping[test] == ref: + return 0 + ref = ref.lower() + test = test.lower() + if ref == test: + return 0 + return 2 + + +def overlap_scoring( + first_element: CtmInterval, + second_element: CtmInterval, + mapping: Optional[dict[str, str]] = None, +) -> float: + r""" + Method to calculate overlap scoring + + .. 
math:: + + Score = -(\lvert begin_{1} - begin_{2} \rvert + \lvert end_{1} - end_{2} \rvert + \begin{cases} + 0, & if label_{1} = label_{2} \\ + 2, & otherwise + \end{cases}) + + See Also + -------- + `Blog post `_ + For a detailed example that using this metric + + Parameters + ---------- + first_element: :class:`~montreal_forced_aligner.textgrid.CtmInterval` + First CTM interval to compare + second_element: :class:`~montreal_forced_aligner.textgrid.CtmInterval` + Second CTM interval + mapping: Optional[dict[str, str]] + Optional mapping of phones to treat as matches even if they have different symbols + + Returns + ------- + float + Score calculated as the negative sum of the absolute different in begin timestamps, absolute difference in end + timestamps and the label score + """ + begin_diff = abs(first_element.begin - second_element.begin) + end_diff = abs(first_element.end - second_element.end) + label_diff = compare_labels(first_element.label, second_element.label, mapping) + return -1 * (begin_diff + end_diff + label_diff) + + +def align_phones( + ref: list[CtmInterval], + test: list[CtmInterval], + silence_phones: set[str], + custom_mapping: Optional[dict[str, str]] = None, +) -> tuple[Optional[float], Optional[int], Optional[int]]: + """ + Align phones based on how much they overlap and their phone label, with the ability to specify a custom mapping for + different phone labels to be scored as if they're the same phone + + Parameters + ---------- + ref: list[:class:`~montreal_forced_aligner.textgrid.CtmInterval`] + List of CTM intervals as reference + test: list[:class:`~montreal_forced_aligner.textgrid.CtmInterval`] + List of CTM intervals to compare to reference + silence_phones: set[str] + Set of silence phones (these are ignored in the final calculation) + custom_mapping: dict[str, str], optional + Optional mapping of phones to treat as matches even if they have different symbols + + Returns + ------- + float + Score based on the average amount of 
overlap in phone intervals + int + Number of insertions + int + Number of deletions + """ + try: + from Bio import pairwise2 + except ImportError: + return None, None, None + if custom_mapping is None: + score_func = overlap_scoring + else: + score_func = functools.partial(overlap_scoring, mapping=custom_mapping) + alignments = pairwise2.align.globalcs( + ref, test, score_func, -5, -5, gap_char=["-"], one_alignment_only=True + ) + overlap_count = 0 + overlap_sum = 0 + num_insertions = 0 + num_deletions = 0 + for a in alignments: + for i, sa in enumerate(a.seqA): + sb = a.seqB[i] + if sa == "-": + if sb.label not in silence_phones: + num_insertions += 1 + else: + continue + elif sb == "-": + if sa.label not in silence_phones: + num_deletions += 1 + else: + continue + else: + overlap_sum += abs(sa.begin - sb.begin) + abs(sa.end - sb.end) + overlap_count += 1 + return overlap_sum / overlap_count, num_insertions, num_deletions diff --git a/montreal_forced_aligner/ivector/__init__.py b/montreal_forced_aligner/ivector/__init__.py new file mode 100644 index 00000000..b3943fd9 --- /dev/null +++ b/montreal_forced_aligner/ivector/__init__.py @@ -0,0 +1,9 @@ +"""Module for ivector extractor training""" + +from montreal_forced_aligner.ivector.trainer import ( + DubmTrainer, + IvectorTrainer, + TrainableIvectorExtractor, +) + +__all__ = ["trainer", "DubmTrainer", "IvectorTrainer", "TrainableIvectorExtractor"] diff --git a/montreal_forced_aligner/ivector/trainer.py b/montreal_forced_aligner/ivector/trainer.py new file mode 100644 index 00000000..bbc09c2b --- /dev/null +++ b/montreal_forced_aligner/ivector/trainer.py @@ -0,0 +1,1127 @@ +"""Class definition for TrainableIvectorExtractor""" +from __future__ import annotations + +import os +import shutil +import subprocess +import time +from typing import TYPE_CHECKING, Any, NamedTuple, Optional + +import yaml + +from ..abc import MetaDict, ModelExporterMixin, TopLevelMfaWorker +from ..acoustic_modeling.base import 
AcousticModelTrainingMixin +from ..corpus.features import IvectorConfigMixin +from ..corpus.ivector_corpus import IvectorCorpusMixin +from ..exceptions import ConfigError, KaldiProcessingError +from ..models import IvectorExtractorModel +from ..utils import log_kaldi_errors, parse_logs, run_mp, run_non_mp, thirdparty_binary + +if TYPE_CHECKING: + from argparse import Namespace + +__all__ = [ + "TrainableIvectorExtractor", + "DubmTrainer", + "IvectorTrainer", + "IvectorModelTrainingMixin", + "acc_ivector_stats_func", +] + + +class IvectorModelTrainingMixin(AcousticModelTrainingMixin): + """ + Abstract mixin for training ivector extractor models + + See Also + -------- + :class:`~montreal_forced_aligner.acoustic_modeling.base.AcousticModelTrainingMixin` + For acoustic model training parsing parameters + """ + + def export_model(self, output_model_path: str) -> None: + """ + Output IvectorExtractor model + + Parameters + ---------- + output_model_path : str + Path to save ivector extractor model + """ + directory, filename = os.path.split(output_model_path) + basename, _ = os.path.splitext(filename) + ivector_extractor = IvectorExtractorModel.empty(basename, self.working_log_directory) + ivector_extractor.add_meta_file(self) + ivector_extractor.add_model(self.working_directory) + if directory: + os.makedirs(directory, exist_ok=True) + basename, _ = os.path.splitext(output_model_path) + ivector_extractor.dump(basename) + + +class GmmGselectArguments(NamedTuple): + """Arguments for :func:`~montreal_forced_aligner.ivector.trainer.gmm_gselect_func`""" + + log_path: str + dictionaries: list[str] + feature_strings: dict[str, str] + ivector_options: MetaDict + dubm_model: str + gselect_paths: dict[str, str] + + +class AccGlobalStatsArguments(NamedTuple): + """Arguments for :func:`~montreal_forced_aligner.ivector.trainer.acc_global_stats_func`""" + + log_path: str + dictionaries: list[str] + feature_strings: dict[str, str] + ivector_options: MetaDict + gselect_paths: 
dict[str, str] + acc_paths: dict[str, str] + dubm_path: str + + +class GaussToPostArguments(NamedTuple): + """Arguments for :func:`~montreal_forced_aligner.ivector.trainer.gauss_to_post_func`""" + + log_path: str + dictionaries: list[str] + feature_strings: dict[str, str] + ivector_options: MetaDict + post_paths: dict[str, str] + dubm_path: str + + +class AccIvectorStatsArguments(NamedTuple): + """Arguments for :func:`~montreal_forced_aligner.ivector.trainer.acc_ivector_stats_func`""" + + log_path: str + dictionaries: list[str] + feature_strings: dict[str, str] + ivector_options: MetaDict + ie_path: str + post_paths: dict[str, str] + acc_init_paths: dict[str, str] + + +def gmm_gselect_func( + log_path: str, + dictionaries: list[str], + feature_strings: dict[str, str], + dubm_options: MetaDict, + dubm_path: str, + gselect_paths: dict[str, str], +) -> None: + """ + Multiprocessing function for selecting GMM indices. + + See Also + -------- + :meth:`.DubmTrainer.gmm_gselect` + Main function that calls this function in parallel + :meth:`.DubmTrainer.gmm_gselect_arguments` + Job method for generating arguments for this function + :kaldi_src:`subsample-feats` + Relevant Kaldi binary + :kaldi_src:`gmm-gselect` + Relevant Kaldi binary + + Parameters + ---------- + log_path: str + Path to save log output + dictionaries: list[str] + List of dictionary names + feature_strings: dict[str, str] + Dictionary of feature strings per dictionary name + dubm_options: dict[str, Any] + Options for DUBM training + dubm_path: str + Path to the DUBM file + gselect_paths: dict[str, str] + Dictionary of gselect archives per dictionary name + """ + with open(log_path, "w", encoding="utf8") as log_file: + for dict_name in dictionaries: + feature_string = feature_strings[dict_name] + gselect_path = gselect_paths[dict_name] + subsample_feats_proc = subprocess.Popen( + [ + thirdparty_binary("subsample-feats"), + f"--n={dubm_options['subsample']}", + feature_string, + "ark:-", + ], + 
stdout=subprocess.PIPE, + stderr=log_file, + env=os.environ, + ) + + gselect_proc = subprocess.Popen( + [ + thirdparty_binary("gmm-gselect"), + f"--n={dubm_options['num_gselect']}", + dubm_path, + "ark:-", + f"ark:{gselect_path}", + ], + stdin=subsample_feats_proc.stdout, + stderr=log_file, + env=os.environ, + ) + gselect_proc.communicate() + + +def gauss_to_post_func( + log_path: str, + dictionaries: list[str], + feature_strings: dict[str, str], + ivector_options: MetaDict, + post_paths: dict[str, str], + dubm_path: str, +): + """ + Multiprocessing function to get posteriors during UBM training. + + See Also + -------- + :meth:`.IvectorTrainer.gauss_to_post` + Main function that calls this function in parallel + :meth:`.IvectorTrainer.gauss_to_post_arguments` + Job method for generating arguments for this function + :kaldi_src:`subsample-feats` + Relevant Kaldi binary + :kaldi_src:`gmm-global-get-post` + Relevant Kaldi binary + :kaldi_src:`scale-post` + Relevant Kaldi binary + + Parameters + ---------- + log_path: str + Path to save log output + dictionaries: list[str] + List of dictionary names + feature_strings: dict[str, str] + Dictionary of feature strings per dictionary name + ivector_options: dict[str, Any] + Options for ivector extractor training + post_paths: dict[str, str] + Dictionary of posterior archives per dictionary name + dubm_path: str + Path to the DUBM file + """ + modified_posterior_scale = ivector_options["posterior_scale"] * ivector_options["subsample"] + with open(log_path, "w", encoding="utf8") as log_file: + for dict_name in dictionaries: + feature_string = feature_strings[dict_name] + post_path = post_paths[dict_name] + subsample_feats_proc = subprocess.Popen( + [ + thirdparty_binary("subsample-feats"), + f"--n={ivector_options['subsample']}", + feature_string, + "ark:-", + ], + stdout=subprocess.PIPE, + stderr=log_file, + env=os.environ, + ) + gmm_global_get_post_proc = subprocess.Popen( + [ + thirdparty_binary("gmm-global-get-post"), + 
f"--n={ivector_options['num_gselect']}", + f"--min-post={ivector_options['min_post']}", + dubm_path, + "ark:-", + "ark:-", + ], + stdout=subprocess.PIPE, + stdin=subsample_feats_proc.stdout, + stderr=log_file, + env=os.environ, + ) + scale_post_proc = subprocess.Popen( + [ + thirdparty_binary("scale-post"), + "ark:-", + str(modified_posterior_scale), + f"ark:{post_path}", + ], + stdin=gmm_global_get_post_proc.stdout, + stderr=log_file, + env=os.environ, + ) + scale_post_proc.communicate() + + +def acc_global_stats_func( + log_path: str, + dictionaries: list[str], + feature_strings: dict[str, str], + dubm_options: MetaDict, + gselect_paths: dict[str, str], + acc_paths: dict[str, str], + dubm_path: str, +) -> None: + """ + Multiprocessing function for accumulating global model stats. + + See Also + -------- + :meth:`.DubmTrainer.acc_global_stats` + Main function that calls this function in parallel + :meth:`.DubmTrainer.acc_global_stats_arguments` + Job method for generating arguments for this function + :kaldi_src:`subsample-feats` + Relevant Kaldi binary + :kaldi_src:`gmm-global-acc-stats` + Relevant Kaldi binary + + Parameters + ---------- + log_path: str + Path to save log output + dictionaries: list[str] + List of dictionary names + feature_strings: dict[str, str] + Dictionary of feature strings per dictionary name + dubm_options: dict[str, Any] + Options for DUBM training + gselect_paths: dict[str, str] + Dictionary of gselect archives per dictionary name + acc_paths: dict[str, str] + Dictionary of accumulated stats files per dictionary name + dubm_path: str + Path to the DUBM file + """ + with open(log_path, "w", encoding="utf8") as log_file: + for dict_name in dictionaries: + feature_string = feature_strings[dict_name] + gselect_path = gselect_paths[dict_name] + acc_path = acc_paths[dict_name] + subsample_feats_proc = subprocess.Popen( + [ + thirdparty_binary("subsample-feats"), + f"--n={dubm_options['subsample']}", + feature_string, + "ark:-", + ], + 
stdout=subprocess.PIPE, + stderr=log_file, + env=os.environ, + ) + gmm_global_acc_proc = subprocess.Popen( + [ + thirdparty_binary("gmm-global-acc-stats"), + f"--gselect=ark:{gselect_path}", + dubm_path, + "ark:-", + acc_path, + ], + stderr=log_file, + stdin=subsample_feats_proc.stdout, + env=os.environ, + ) + gmm_global_acc_proc.communicate() + + +def acc_ivector_stats_func( + log_path: str, + dictionaries: list[str], + feature_strings: dict[str, str], + ivector_options: MetaDict, + ie_path: str, + post_paths: dict[str, str], + acc_init_paths: dict[str, str], +) -> None: + """ + Multiprocessing function that accumulates stats for ivector training. + + See Also + -------- + :meth:`.IvectorTrainer.acc_ivector_stats` + Main function that calls this function in parallel + :meth:`.IvectorTrainer.acc_ivector_stats_arguments` + Job method for generating arguments for this function + :kaldi_src:`subsample-feats` + Relevant Kaldi binary + :kaldi_src:`ivector-extractor-acc-stats` + Relevant Kaldi binary + + Parameters + ---------- + log_path: str + Path to save log output + dictionaries: list[str] + List of dictionary names + feature_strings: dict[str, str] + Dictionary of feature strings per dictionary name + ivector_options: dict[str, Any] + Options for ivector extractor training + ie_path: str + Path to the ivector extractor file + post_paths: dict[str, str] + Dictionary of posterior archives per dictionary name + acc_init_paths: dict[str, str] + Dictionary of accumulated stats files per dictionary name + """ + with open(log_path, "w", encoding="utf8") as log_file: + for dict_name in dictionaries: + feature_string = feature_strings[dict_name] + post_path = post_paths[dict_name] + acc_init_path = acc_init_paths[dict_name] + subsample_feats_proc = subprocess.Popen( + [ + thirdparty_binary("subsample-feats"), + f"--n={ivector_options['subsample']}", + feature_string, + "ark:-", + ], + stdout=subprocess.PIPE, + stderr=log_file, + env=os.environ, + ) + acc_stats_proc = 
subprocess.Popen( + [ + thirdparty_binary("ivector-extractor-acc-stats"), + "--num-threads=1", + ie_path, + "ark:-", + f"ark:{post_path}", + acc_init_path, + ], + stdin=subsample_feats_proc.stdout, + stderr=log_file, + env=os.environ, + ) + acc_stats_proc.communicate() + + +class DubmTrainer(IvectorModelTrainingMixin): + """ + Trainer for diagonal universal background models + + Parameters + ---------- + num_iterations : int + Number of training iterations to perform, defaults to 4 + num_gselect: int + Number of Gaussian-selection indices to use while training + subsample: int + Subsample factor for feature frames, defaults to 5 + num_frames:int + Number of frames to keep in memory for initialization, defaults to 500000 + num_gaussians:int + Number of gaussians to use for DUBM training, defaults to 256 + num_iterations_init:int + Number of iteration to use when initializing UBM, defaults to 20 + initial_gaussian_proportion:float + Proportion of total gaussians to use initially, defaults to 0.5 + min_gaussian_weight: float + Defaults to 0.0001 + remove_low_count_gaussians: bool + Flag for removing low count gaussians in the final round of training, defaults to True + + See Also + -------- + :class:`~montreal_forced_aligner.ivector.trainer.IvectorModelTrainingMixin` + For base ivector training parameters + """ + + def __init__( + self, + num_iterations: int = 4, + num_gselect: int = 30, + subsample: int = 5, + num_frames: int = 500000, + num_gaussians: int = 256, + num_iterations_init: int = 20, + initial_gaussian_proportion: float = 0.5, + min_gaussian_weight: float = 0.0001, + remove_low_count_gaussians: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + + self.num_iterations = num_iterations + self.subsample = subsample + self.num_gselect = num_gselect + self.num_frames = num_frames + self.num_gaussians = num_gaussians + self.num_iterations_init = num_iterations_init + self.initial_gaussian_proportion = initial_gaussian_proportion + 
self.min_gaussian_weight = min_gaussian_weight + self.remove_low_count_gaussians = remove_low_count_gaussians + + def compute_calculated_properties(self) -> None: + pass + + @property + def train_type(self) -> str: + """Training identifier""" + return "dubm" + + @property + def dubm_options(self): + """Options for DUBM training""" + return {"subsample": self.subsample, "num_gselect": self.num_gselect} + + def gmm_gselect_arguments(self) -> list[GmmGselectArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.ivector.trainer.gmm_gselect_func` + + Returns + ------- + list[:class:`~montreal_forced_aligner.ivector.trainer.GmmGselectArguments`] + Arguments for processing + """ + feat_strings = self.construct_feature_proc_strings() + return [ + GmmGselectArguments( + os.path.join(self.working_log_directory, f"gmm_gselect.{j.name}.log"), + j.current_dictionary_names, + feat_strings[j.name], + self.dubm_options, + self.model_path, + j.construct_path_dictionary(self.working_directory, "gselect", "ark"), + ) + for j in self.jobs + ] + + def acc_global_stats_arguments( + self, + ) -> list[AccGlobalStatsArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.ivector.trainer.acc_global_stats_func` + + + Returns + ------- + list[:class:`~montreal_forced_aligner.ivector.trainer.AccGlobalStatsArguments`] + Arguments for processing + """ + feat_strings = self.construct_feature_proc_strings() + return [ + AccGlobalStatsArguments( + os.path.join( + self.working_log_directory, + f"acc_global_stats.{self.iteration}.{j.name}.log", + ), + j.current_dictionary_names, + feat_strings[j.name], + self.dubm_options, + j.construct_path_dictionary(self.working_directory, "gselect", "ark"), + j.construct_path_dictionary( + self.working_directory, f"global.{self.iteration}", "acc" + ), + self.model_path, + ) + for j in self.jobs + ] + + def gmm_gselect(self) -> None: + """ + Multiprocessing function that stores Gaussian selection indices on disk + + 
See Also + -------- + :func:`~montreal_forced_aligner.ivector.trainer.gmm_gselect_func` + Multiprocessing helper function for each job + :meth:`.DubmTrainer.gmm_gselect_arguments` + Job method for generating arguments for the helper function + :kaldi_steps:`train_diag_ubm` + Reference Kaldi script + + """ + jobs = self.gmm_gselect_arguments() + if self.use_mp: + run_mp(gmm_gselect_func, jobs, self.working_log_directory) + else: + run_non_mp(gmm_gselect_func, jobs, self.working_log_directory) + + def _trainer_initialization(self, initial_alignment_directory: Optional[str] = None) -> None: + """DUBM training initialization""" + # Initialize model from E-M in memory + log_directory = os.path.join(self.working_directory, "log") + if initial_alignment_directory and os.path.exists(initial_alignment_directory): + jobs = self.align_arguments() + for j in jobs: + for p in j.ali_paths.values(): + shutil.copyfile( + p.replace(self.working_directory, initial_alignment_directory), p + ) + shutil.copyfile( + os.path.join(initial_alignment_directory, "final.mdl"), + os.path.join(self.working_directory, "final.mdl"), + ) + num_gauss_init = int(self.initial_gaussian_proportion * int(self.num_gaussians)) + log_path = os.path.join(log_directory, "gmm_init.log") + all_feats_path = os.path.join(self.corpus_output_directory, "feats.scp") + feature_string = self.construct_base_feature_string(all_feats=True) + with open(all_feats_path, "w") as outf: + for i in self.jobs: + feat_paths = i.construct_path_dictionary(self.data_directory, "feats", "scp") + for p in feat_paths.values(): + with open(p) as inf: + for line in inf: + outf.write(line) + self.iteration = 1 + with open(log_path, "w") as log_file: + gmm_init_proc = subprocess.Popen( + [ + thirdparty_binary("gmm-global-init-from-feats"), + f"--num-threads={self.worker.num_jobs}", + f"--num-frames={self.num_frames}", + f"--num_gauss={self.num_gaussians}", + f"--num_gauss_init={num_gauss_init}", + 
f"--num_iters={self.num_iterations_init}", + feature_string, + self.model_path, + ], + stderr=log_file, + ) + gmm_init_proc.communicate() + # Store Gaussian selection indices on disk + self.gmm_gselect() + parse_logs(log_directory) + + def acc_global_stats(self) -> None: + """ + Multiprocessing function that accumulates global GMM stats + + See Also + -------- + :func:`~montreal_forced_aligner.ivector.trainer.acc_global_stats_func` + Multiprocessing helper function for each job + :meth:`.DubmTrainer.acc_global_stats_arguments` + Job method for generating arguments for the helper function + :kaldi_src:`gmm-global-sum-accs` + Relevant Kaldi binary + :kaldi_steps:`train_diag_ubm` + Reference Kaldi script + + """ + jobs = self.acc_global_stats_arguments() + if self.use_mp: + run_mp(acc_global_stats_func, jobs, self.working_log_directory) + else: + run_non_mp(acc_global_stats_func, jobs, self.working_log_directory) + + # Don't remove low-count Gaussians till the last tier, + # or gselect info won't be valid anymore + if self.iteration < self.num_iterations: + opt = "--remove-low-count-gaussians=false" + else: + opt = f"--remove-low-count-gaussians={self.remove_low_count_gaussians}" + log_path = os.path.join(self.working_log_directory, f"update.{self.iteration}.log") + with open(log_path, "w") as log_file: + acc_files = [] + for j in jobs: + acc_files.extend(j.acc_paths.values()) + sum_proc = subprocess.Popen( + [thirdparty_binary("gmm-global-sum-accs"), "-"] + acc_files, + stderr=log_file, + stdout=subprocess.PIPE, + env=os.environ, + ) + gmm_global_est_proc = subprocess.Popen( + [ + thirdparty_binary("gmm-global-est"), + opt, + f"--min-gaussian-weight={self.min_gaussian_weight}", + self.model_path, + "-", + self.next_model_path, + ], + stderr=log_file, + stdin=sum_proc.stdout, + env=os.environ, + ) + gmm_global_est_proc.communicate() + # Clean up + if not self.debug: + for p in acc_files: + os.remove(p) + + @property + def exported_model_path(self) -> str: + 
"""Temporary model path to save intermediate model""" + return os.path.join(self.working_log_directory, "dubm_model.zip") + + def train_iteration(self) -> None: + """ + Run an iteration of UBM training + """ + # Accumulate stats + self.acc_global_stats() + self.iteration += 1 + + def finalize_training(self) -> None: + """Finalize DUBM training""" + final_dubm_path = os.path.join(self.working_directory, "final.dubm") + shutil.copy( + os.path.join(self.working_directory, f"{self.num_iterations+1}.dubm"), + final_dubm_path, + ) + self.export_model(self.exported_model_path) + self.training_complete = True + + @property + def model_path(self) -> str: + """Current iteration's DUBM model path""" + if self.training_complete: + return os.path.join(self.working_directory, "final.dubm") + return os.path.join(self.working_directory, f"{self.iteration}.dubm") + + @property + def next_model_path(self) -> str: + """Next iteration's DUBM model path""" + if self.training_complete: + return os.path.join(self.working_directory, "final.dubm") + return os.path.join(self.working_directory, f"{self.iteration + 1}.dubm") + + +class IvectorTrainer(IvectorModelTrainingMixin, IvectorConfigMixin): + """ + Trainer for a block of ivector extractor training + + Parameters + ---------- + num_iterations: int + Number of iterations, defaults to 10 + subsample: int + Subsample factor for feature frames, defaults to 5 + gaussian_min_count: int + + See Also + -------- + :class:`~montreal_forced_aligner.ivector.trainer.IvectorModelTrainingMixin` + For base parameters for ivector training + :class:`~montreal_forced_aligner.corpus.features.IvectorConfigMixin` + For parameters for ivector feature generation + + """ + + def __init__( + self, num_iterations: int = 10, subsample: int = 5, gaussian_min_count: int = 100, **kwargs + ): + super().__init__(**kwargs) + self.subsample = subsample + self.num_iterations = num_iterations + self.gaussian_min_count = gaussian_min_count + + def 
compute_calculated_properties(self) -> None: + pass + + @property + def exported_model_path(self) -> str: + """Temporary directory path that trainer will save ivector extractor model""" + return os.path.join(self.working_log_directory, "ivector_model.zip") + + def acc_ivector_stats_arguments(self) -> list[AccIvectorStatsArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.ivector.trainer.acc_ivector_stats_func` + + Returns + ------- + list[:class:`~montreal_forced_aligner.ivector.trainer.AccIvectorStatsArguments`] + Arguments for processing + """ + feat_strings = self.construct_feature_proc_strings() + arguments = [ + AccIvectorStatsArguments( + os.path.join(self.working_log_directory, f"ivector_acc.{j.name}.log"), + j.current_dictionary_names, + feat_strings[j.name], + self.ivector_options, + self.ie_path, + j.construct_path_dictionary(self.working_directory, "post", "ark"), + j.construct_path_dictionary(self.working_directory, "ivector", "acc"), + ) + for j in self.jobs + ] + + return arguments + + def _trainer_initialization(self) -> None: + """Ivector extractor training initialization""" + self.iteration = 1 + self.training_complete = False + # Initialize job_name-vector extractor + log_directory = os.path.join(self.working_directory, "log") + log_path = os.path.join(log_directory, "init.log") + diag_ubm_path = os.path.join(self.working_directory, "final.dubm") + + full_ubm_path = os.path.join(self.working_directory, "final.ubm") + with open(log_path, "w") as log_file: + subprocess.call( + [thirdparty_binary("gmm-global-to-fgmm"), diag_ubm_path, full_ubm_path], + stderr=log_file, + ) + subprocess.call( + [ + thirdparty_binary("ivector-extractor-init"), + f"--ivector-dim={self.ivector_dimension}", + "--use-weights=false", + full_ubm_path, + self.ie_path, + ], + stderr=log_file, + ) + + # Do Gaussian selection and posterior extraction + self.gauss_to_post() + parse_logs(log_directory) + + def gauss_to_post_arguments(self) -> 
list[GaussToPostArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.ivector.trainer.gauss_to_post_func` + + Returns + ------- + list[:class:`~montreal_forced_aligner.ivector.trainer.GaussToPostArguments`] + Arguments for processing + """ + feat_strings = self.construct_feature_proc_strings() + return [ + GaussToPostArguments( + os.path.join(self.working_log_directory, f"gauss_to_post.{j.name}.log"), + j.current_dictionary_names, + feat_strings[j.name], + self.ivector_options, + j.construct_path_dictionary(self.working_directory, "post", "ark"), + self.dubm_path, + ) + for j in self.jobs + ] + + def gauss_to_post(self) -> None: + """ + Multiprocessing function that does Gaussian selection and posterior extraction + + See Also + -------- + :func:`~montreal_forced_aligner.ivector.trainer.gauss_to_post_func` + Multiprocessing helper function for each job + :meth:`.IvectorTrainer.gauss_to_post_arguments` + Job method for generating arguments for the helper function + :kaldi_steps_sid:`train_ivector_extractor` + Reference Kaldi script + """ + jobs = self.gauss_to_post_arguments() + if self.use_mp: + run_mp(gauss_to_post_func, jobs, self.working_log_directory) + else: + run_non_mp(gauss_to_post_func, jobs, self.working_log_directory) + + @property + def train_type(self) -> str: + """Training identifier""" + return "ivector" + + @property + def ivector_options(self) -> MetaDict: + """Options for ivector training and extracting""" + options = super().ivector_options + options["subsample"] = self.subsample + return options + + @property + def meta(self) -> MetaDict: + """Metadata information for ivector extractor models""" + from ..utils import get_mfa_version + + return { + "version": get_mfa_version(), + "ivector_dimension": self.ivector_dimension, + "num_gselect": self.num_gselect, + "min_post": self.min_post, + "posterior_scale": self.posterior_scale, + "features": self.feature_options, + } + + @property + def ie_path(self) -> str: + """Current 
ivector extractor model path""" + if self.training_complete: + return os.path.join(self.working_directory, "final.ie") + return os.path.join(self.working_directory, f"{self.iteration}.ie") + + @property + def next_ie_path(self) -> str: + """Next iteration's ivector extractor model path""" + if self.training_complete: + return os.path.join(self.working_directory, "final.ie") + return os.path.join(self.working_directory, f"{self.iteration + 1}.ie") + + @property + def dubm_path(self) -> str: + """DUBM model path""" + return os.path.join(self.working_directory, "final.dubm") + + def acc_ivector_stats(self) -> None: + """ + Multiprocessing function that accumulates ivector extraction stats. + + See Also + -------- + :func:`~montreal_forced_aligner.ivector.trainer.acc_ivector_stats_func` + Multiprocessing helper function for each job + :meth:`.IvectorTrainer.acc_ivector_stats_arguments` + Job method for generating arguments for the helper function + :kaldi_src:`ivector-extractor-sum-accs` + Relevant Kaldi binary + :kaldi_src:`ivector-extractor-est` + Relevant Kaldi binary + :kaldi_steps_sid:`train_ivector_extractor` + Reference Kaldi script + """ + + jobs = self.acc_ivector_stats_arguments() + if self.use_mp: + run_mp(acc_ivector_stats_func, jobs, self.working_log_directory) + else: + run_non_mp(acc_ivector_stats_func, jobs, self.working_log_directory) + + log_path = os.path.join(self.working_log_directory, f"sum_acc.{self.iteration}.log") + acc_path = os.path.join(self.working_directory, f"acc.{self.iteration}") + with open(log_path, "w", encoding="utf8") as log_file: + accinits = [] + for j in jobs: + accinits.extend(j.acc_init_paths.values()) + sum_accs_proc = subprocess.Popen( + [thirdparty_binary("ivector-extractor-sum-accs"), "--parallel=true"] + + accinits + + [acc_path], + stderr=log_file, + env=os.environ, + ) + + sum_accs_proc.communicate() + # clean up + for p in accinits: + os.remove(p) + # Est extractor + log_path = os.path.join(self.working_log_directory, 
f"update.{self.iteration}.log") + with open(log_path, "w") as log_file: + extractor_est_proc = subprocess.Popen( + [ + thirdparty_binary("ivector-extractor-est"), + f"--num-threads={len(self.jobs)}", + f"--gaussian-min-count={self.gaussian_min_count}", + self.ie_path, + os.path.join(self.working_directory, f"acc.{self.iteration}"), + self.next_ie_path, + ], + stderr=log_file, + env=os.environ, + ) + extractor_est_proc.communicate() + + def train_iteration(self): + """ + Run an iteration of training + """ + if os.path.exists(self.next_ie_path): + self.iteration += 1 + return + # Accumulate stats and sum + self.acc_ivector_stats() + + self.iteration += 1 + + def finalize_training(self): + """ + Finalize ivector extractor training + """ + # Rename to final + shutil.copy( + os.path.join(self.working_directory, f"{self.num_iterations}.ie"), + os.path.join(self.working_directory, "final.ie"), + ) + self.training_complete = True + + +class TrainableIvectorExtractor(IvectorCorpusMixin, TopLevelMfaWorker, ModelExporterMixin): + """ + Trainer for ivector extractor models + + Parameters + ---------- + training_configuration: list[tuple[str, dict[str, Any]]] + Training configurations to use, defaults to a round of dubm training followed by ivector training + + See Also + -------- + :class:`~montreal_forced_aligner.corpus.ivector_corpus.IvectorCorpusMixin` + For parameters to parse corpora using ivector features + :class:`~montreal_forced_aligner.abc.TopLevelMfaWorker` + For top-level parameters + :class:`~montreal_forced_aligner.abc.ModelExporterMixin` + For model export parameters + """ + + def __init__(self, training_configuration: list[tuple[str, dict[str, Any]]] = None, **kwargs): + self.param_dict = { + k: v + for k, v in kwargs.items() + if not k.endswith("_directory") + and not k.endswith("_path") + and k not in ["clean", "num_jobs", "speaker_characters"] + } + self.final_identifier = None + super().__init__(**kwargs) + os.makedirs(self.output_directory, exist_ok=True) 
+ self.training_configs: dict[str, AcousticModelTrainingMixin] = {} + self.current_model = None + if training_configuration is None: + training_configuration = [("dubm", {}), ("ivector", {})] + for k, v in training_configuration: + self.add_config(k, v) + + def setup(self) -> None: + """Setup ivector extractor training""" + if self.initialized: + return + self.check_previous_run() + try: + self.load_corpus() + except Exception as e: + if isinstance(e, KaldiProcessingError): + log_kaldi_errors(e.error_logs, self.logger) + e.update_log_file(self.logger) + raise + self.initialized = True + + def add_config(self, train_type: str, params: MetaDict) -> None: + """ + Add a trainer to the pipeline + + Parameters + ---------- + train_type: str + Type of trainer to add, one of "dubm" or "ivector" + params: dict[str, Any] + Parameters to initialize trainer + + Raises + ------ + ConfigError + If an invalid ``train_type`` is specified + """ + p = {} + p.update(self.param_dict) + p.update(params) + identifier = train_type + index = 1 + while identifier in self.training_configs: + identifier = f"{train_type}_{index}" + index += 1 + self.final_identifier = identifier + if train_type == "dubm": + config = DubmTrainer(identifier=identifier, worker=self, **p) + elif train_type == "ivector": + config = IvectorTrainer(identifier=identifier, worker=self, **p) + else: + raise ConfigError(f"Invalid training type '{train_type}' in config file") + + self.training_configs[identifier] = config + + @property + def workflow_identifier(self) -> str: + """Ivector training identifier""" + return "train_ivector" + + @property + def meta(self) -> MetaDict: + """Metadata about the final round of training""" + return self.training_configs[self.final_identifier].meta + + def train(self) -> None: + """ + Run through the training configurations to produce a final ivector extractor model + """ + begin = time.time() + self.setup() + previous = None + for trainer in self.training_configs.values(): + 
self.current_subset = trainer.subset + if previous is not None: + self.current_model = IvectorExtractorModel(previous.exported_model_path) + os.makedirs(trainer.working_directory, exist_ok=True) + self.current_model.export_model(trainer.working_directory) + trainer.train() + previous = trainer + self.logger.info(f"Completed training in {time.time()-begin} seconds!") + + def export_model(self, output_model_path: str) -> None: + """ + Export an ivector extractor model to the specified path + + Parameters + ---------- + output_model_path : str + Path to save ivector extractor model + """ + self.training_configs[self.final_identifier].export_model(output_model_path) + self.logger.info(f"Saved model to {output_model_path}") + + @classmethod + def parse_parameters( + cls, + config_path: Optional[str] = None, + args: Optional[Namespace] = None, + unknown_args: Optional[list[str]] = None, + ) -> MetaDict: + """ + Parse configuration parameters from a config file and command line arguments + + Parameters + ---------- + config_path: str, optional + Path to yaml configuration file + args: :class:`~argparse.Namespace`, optional + Arguments parsed by argparse + unknown_args: list[str], optional + List of unknown arguments from argparse + + Returns + ------- + dict[str, Any] + Dictionary of specified configuration parameters + """ + global_params = {} + training_params = [] + if config_path is not None: + with open(config_path, "r", encoding="utf8") as f: + data = yaml.load(f, Loader=yaml.SafeLoader) + training_params = [] + for k, v in data.items(): + if k == "training": + for t in v: + for k2, v2 in t.items(): + if "features" in v2: + global_params.update(v2["features"]) + del v2["features"] + training_params.append((k2, v2)) + elif k == "features": + if "type" in v: + v["feature_type"] = v["type"] + del v["type"] + global_params.update(v) + else: + global_params[k] = v + if not training_params: + raise ConfigError(f"No 'training' block found in {config_path}") + else: # 
default training configuration + training_params.append(("dubm", {})) + # training_params.append(("ubm", {})) + training_params.append(("ivector", {})) + if training_params: + if training_params[0][0] != "dubm": + raise ConfigError("The first round of training must be dubm.") + global_params["training_configuration"] = training_params + global_params.update(cls.parse_args(args, unknown_args)) + return global_params diff --git a/montreal_forced_aligner/language_modeling/__init__.py b/montreal_forced_aligner/language_modeling/__init__.py new file mode 100644 index 00000000..997cdd93 --- /dev/null +++ b/montreal_forced_aligner/language_modeling/__init__.py @@ -0,0 +1,14 @@ +""" +Language modeling +================= + + +""" + +from montreal_forced_aligner.language_modeling.trainer import ( + LmArpaTrainer, + LmCorpusTrainer, + LmDictionaryCorpusTrainer, +) + +__all__ = ["LmCorpusTrainer", "LmDictionaryCorpusTrainer", "LmArpaTrainer"] diff --git a/montreal_forced_aligner/language_modeling/trainer.py b/montreal_forced_aligner/language_modeling/trainer.py new file mode 100644 index 00000000..8ab481c5 --- /dev/null +++ b/montreal_forced_aligner/language_modeling/trainer.py @@ -0,0 +1,443 @@ +"""Classes for training language models""" +from __future__ import annotations + +import os +import re +import subprocess +from typing import TYPE_CHECKING, Generator + +from montreal_forced_aligner.abc import TopLevelMfaWorker, TrainerMixin +from montreal_forced_aligner.corpus.text_corpus import MfaWorker, TextCorpusMixin +from montreal_forced_aligner.dictionary.mixins import DictionaryMixin +from montreal_forced_aligner.dictionary.multispeaker import MultispeakerDictionaryMixin +from montreal_forced_aligner.models import LanguageModel + +if TYPE_CHECKING: + from montreal_forced_aligner.abc import MetaDict + +__all__ = ["LmCorpusTrainer", "LmTrainerMixin", "LmArpaTrainer", "LmDictionaryCorpusTrainer"] + + +class LmTrainerMixin(DictionaryMixin, TrainerMixin, MfaWorker): + """ + 
Abstract mixin class for training language models + + Parameters + ---------- + prune_method: str + Pruning method for pruning the ngram model, defaults to "relative_entropy" + prune_thresh_small: float + Pruning threshold for the small language model, defaults to 0.0000003 + prune_thresh_medium: float + Pruning threshold for the medium language model, defaults to 0.0000001 + + See Also + -------- + :class:`~montreal_forced_aligner.dictionary.mixins.DictionaryMixin` + For dictionary parsing parameters + :class:`~montreal_forced_aligner.abc.TrainerMixin` + For training parameters + :class:`~montreal_forced_aligner.abc.MfaWorker` + For worker parameters + """ + + def __init__( + self, + prune_method="relative_entropy", + order: int = 3, + method: str = "kneser_ney", + prune_thresh_small=0.0000003, + prune_thresh_medium=0.0000001, + **kwargs, + ): + super().__init__(**kwargs) + self.prune_method = prune_method + self.order = order + self.method = method + self.prune_thresh_small = prune_thresh_small + self.prune_thresh_medium = prune_thresh_medium + + @property + def mod_path(self) -> str: + """Internal temporary path to the model file""" + return os.path.join(self.working_directory, f"{self.data_source_identifier}.mod") + + @property + def far_path(self) -> str: + """Internal temporary path to the FAR file""" + return os.path.join(self.working_directory, f"{self.data_source_identifier}.far") + + @property + def large_arpa_path(self) -> str: + """Internal temporary path to the large arpa file""" + return os.path.join(self.working_directory, f"{self.data_source_identifier}.arpa") + + @property + def medium_arpa_path(self) -> str: + """Internal temporary path to the medium arpa file""" + return self.large_arpa_path.replace(".arpa", "_med.arpa") + + @property + def small_arpa_path(self) -> str: + """Internal temporary path to the small arpa file""" + return self.large_arpa_path.replace(".arpa", "_small.arpa") + + def initialize_training(self) -> None: + """Initialize 
training""" + pass + + def train_iteration(self) -> None: + """Run one training iteration""" + pass + + def finalize_training(self) -> None: + """Finalize training""" + pass + + def prune_large_language_model(self) -> None: + """Prune the large language model into small and medium versions""" + self.log_info("Pruning large ngram model to medium and small versions...") + small_mod_path = self.mod_path.replace(".mod", "_small.mod") + med_mod_path = self.mod_path.replace(".mod", "_med.mod") + subprocess.call( + [ + "ngramshrink", + f"--method={self.prune_method}", + f"--theta={self.prune_thresh_medium}", + self.mod_path, + med_mod_path, + ] + ) + subprocess.call(["ngramprint", "--ARPA", med_mod_path, self.medium_arpa_path]) + + self.log_debug("Finished pruning medium arpa!") + subprocess.call( + [ + "ngramshrink", + f"--method={self.prune_method}", + f"--theta={self.prune_thresh_small}", + self.mod_path, + small_mod_path, + ] + ) + subprocess.call(["ngramprint", "--ARPA", small_mod_path, self.small_arpa_path]) + + self.log_debug("Finished pruning small arpa!") + self.log_info("Done pruning!") + + def export_model(self, output_model_path: str) -> None: + """ + Export language model to specified path + + Parameters + ---------- + output_model_path:str + Path to export model + """ + directory, filename = os.path.split(output_model_path) + basename, _ = os.path.splitext(filename) + model_temp_dir = os.path.join(self.working_directory, "model_archiving") + os.makedirs(model_temp_dir, exist_ok=True) + model = LanguageModel.empty(basename, root_directory=model_temp_dir) + model.add_meta_file(self) + model.add_arpa_file(self.large_arpa_path) + model.add_arpa_file(self.medium_arpa_path) + model.add_arpa_file(self.small_arpa_path) + basename, _ = os.path.splitext(output_model_path) + model.dump(basename) + + +class LmCorpusTrainer(LmTrainerMixin, TextCorpusMixin, TopLevelMfaWorker): + """ + Top-level worker to train a language model from a text corpus + + Parameters + 
---------- + order: int + Ngram order, defaults to 3 + method:str + Smoothing method for the ngram model, defaults to "kneser_ney" + count_threshold:int + Minimum count needed to not be treated as an OOV item, defaults to 1 + + See Also + -------- + :class:`~montreal_forced_aligner.language_modeling.trainer.LmTrainerMixin` + For language model training parsing parameters + :class:`~montreal_forced_aligner.corpus.text_corpus.TextCorpusMixin` + For corpus parsing parameters + :class:`~montreal_forced_aligner.abc.TopLevelMfaWorker` + For top-level parameters + """ + + def __init__(self, count_threshold: int = 1, **kwargs): + super().__init__(**kwargs) + self.count_threshold = count_threshold + + def setup(self) -> None: + """Set up language model training""" + if self.initialized: + return + os.makedirs(self.working_log_directory, exist_ok=True) + self._load_corpus() + + with open(self.training_path, "w", encoding="utf8") as f: + for text in self.normalized_text_iter(self.count_threshold): + f.write(f"{text}\n") + + self.save_oovs_found(self.working_directory) + + subprocess.call( + ["ngramsymbols", f'--OOV_symbol="{self.oov_word}"', self.training_path, self.sym_path] + ) + self.initialized = True + + @property + def training_path(self): + """Internal path to training data""" + return os.path.join(self.working_directory, "training.txt") + + @property + def sym_path(self): + """Internal path to symbols file""" + return os.path.join(self.working_directory, f"{self.data_source_identifier}.sym") + + @property + def far_path(self): + """Internal path to FAR file""" + return os.path.join(self.working_directory, f"{self.data_source_identifier}.far") + + @property + def cnts_path(self): + """Internal path to counts file""" + return os.path.join(self.working_directory, f"{self.data_source_identifier}.cnts") + + @property + def workflow_identifier(self) -> str: + """Language model trainer identifier""" + return "train_lm_corpus" + + @property + def meta(self) -> MetaDict: + 
"""Metadata information for the language model""" + from ..utils import get_mfa_version + + return { + "type": "ngram", + "order": self.order, + "method": self.method, + "version": get_mfa_version(), + } + + def evaluate(self) -> None: + """ + Run an evaluation over the training data to generate perplexity score + """ + log_path = os.path.join(self.working_log_directory, "evaluate.log") + + small_mod_path = self.mod_path.replace(".mod", "_small.mod") + med_mod_path = self.mod_path.replace(".mod", "_med.mod") + with open(log_path, "w", encoding="utf8") as log_file: + perplexity_proc = subprocess.Popen( + [ + "ngramperplexity", + f'--OOV_symbol="{self.oov_word}"', + self.mod_path, + self.far_path, + ], + stdout=subprocess.PIPE, + stderr=log_file, + text=True, + ) + stdout, stderr = perplexity_proc.communicate() + num_sentences = None + num_words = None + num_oovs = None + perplexity = None + for line in stdout.splitlines(): + m = re.search(r"(\d+) sentences", line) + if m: + num_sentences = m.group(0) + m = re.search(r"(\d+) words", line) + if m: + num_words = m.group(0) + m = re.search(r"(\d+) OOVs", line) + if m: + num_oovs = m.group(0) + m = re.search(r"perplexity = ([\d.]+)", line) + if m: + perplexity = m.group(0) + + self.log_info(f"{num_sentences} sentences, {num_words} words, {num_oovs} oovs") + self.log_info(f"Perplexity of large model: {perplexity}") + + perplexity_proc = subprocess.Popen( + [ + "ngramperplexity", + f'--OOV_symbol="{self.oov_word}"', + med_mod_path, + self.far_path, + ], + stdout=subprocess.PIPE, + stderr=log_file, + text=True, + ) + stdout, stderr = perplexity_proc.communicate() + + perplexity = None + for line in stdout.splitlines(): + m = re.search(r"perplexity = ([\d.]+)", line) + if m: + perplexity = m.group(0) + self.log_info(f"Perplexity of medium model: {perplexity}") + perplexity_proc = subprocess.Popen( + [ + "ngramperplexity", + f'--OOV_symbol="{self.oov_word}"', + small_mod_path, + self.far_path, + ], + stdout=subprocess.PIPE, + 
stderr=log_file, + text=True, + ) + stdout, stderr = perplexity_proc.communicate() + + perplexity = None + for line in stdout.splitlines(): + m = re.search(r"perplexity = ([\d.]+)", line) + if m: + perplexity = m.group(0) + self.log_info(f"Perplexity of small model: {perplexity}") + + def normalized_text_iter(self, min_count: int = 1) -> Generator: + """ + Construct an iterator over the normalized texts in the corpus + + Parameters + ---------- + min_count: int + Minimum word count to include in the output, otherwise will use OOV code, defaults to 1 + + Yields + ------- + str + Normalized text + """ + unk_words = {k for k, v in self.word_counts.items() if v <= min_count} + for u in self.utterances.values(): + text = u.text.split() + new_text = [] + for t in text: + if u.speaker.dictionary is not None: + u.speaker.dictionary.to_int(t) + lookup = u.speaker.dictionary.split_clitics(t) + if lookup is None: + continue + else: + lookup = [t] + for item in lookup: + if item in unk_words: + new_text.append(self.oov_word) + self.oovs_found[item] += 1 + elif ( + u.speaker.dictionary is not None and item not in u.speaker.dictionary.words + ): + new_text.append(self.oov_word) + else: + new_text.append(item) + yield " ".join(new_text) + + def train(self) -> None: + """ + Train a language model + """ + self.log_info("Beginning training large ngram model...") + subprocess.call( + [ + "farcompilestrings", + "--fst_type=compact", + f'--unknown_symbol="{self.oov_word}"', + f"--symbols={self.sym_path}", + "--keep_symbols", + self.training_path, + self.far_path, + ] + ) + subprocess.call(["ngramcount", f"--order={self.order}", self.far_path, self.cnts_path]) + subprocess.call(["ngrammake", f"--method={self.method}", self.cnts_path, self.mod_path]) + self.log_info("Done!") + + subprocess.call(["ngramprint", "--ARPA", self.mod_path, self.large_arpa_path]) + + self.log_info("Large ngam model created!") + + self.prune_large_language_model() + self.evaluate() + + +class 
LmDictionaryCorpusTrainer(MultispeakerDictionaryMixin, LmCorpusTrainer): + """ + Top-level worker to train a language model and incorporate a pronunciation dictionary for marking words as OOV + + See Also + -------- + :class:`~montreal_forced_aligner.language_modeling.trainer.LmTrainerMixin` + For language model training parsing parameters + :class:`~montreal_forced_aligner.dictionary.multispeaker.MultispeakerDictionaryMixin` + For dictionary parsing parameters + """ + + pass + + +class LmArpaTrainer(LmTrainerMixin, TopLevelMfaWorker): + """ + Top-level worker to convert an existing ARPA-format language model to MFA format + + See Also + -------- + :class:`~montreal_forced_aligner.language_modeling.trainer.LmTrainerMixin` + For language model training parsing parameters + :class:`~montreal_forced_aligner.abc.TopLevelMfaWorker` + For top-level parsing parameters + """ + + def __init__(self, arpa_path: str, **kwargs): + self.arpa_path = arpa_path + super().__init__(**kwargs) + + def setup(self) -> None: + """Set up language model training""" + os.makedirs(self.working_log_directory, exist_ok=True) + with open(self.arpa_path, "r", encoding="utf8") as inf, open( + self.large_arpa_path, "w", encoding="utf8" + ) as outf: + for line in inf: + outf.write(line.lower()) + self.initialized = True + + @property + def data_directory(self) -> str: + return "" + + @property + def workflow_identifier(self) -> str: + return "train_lm_from_arpa" + + @property + def data_source_identifier(self) -> str: + return os.path.splitext(os.path.basename(self.arpa_path))[0] + + @property + def meta(self) -> MetaDict: + return {} + + def train(self) -> None: + """Convert the arpa model to MFA format""" + self.log_info("Parsing large ngram model...") + subprocess.call(["ngramread", "--ARPA", self.large_arpa_path, self.mod_path]) + + self.log_info("Large ngam model parsed!") + + self.prune_large_language_model() diff --git a/montreal_forced_aligner/lm/__init__.py 
b/montreal_forced_aligner/lm/__init__.py deleted file mode 100644 index 290851df..00000000 --- a/montreal_forced_aligner/lm/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -""" -Language modeling -================= - - -""" - -from .trainer import LmTrainer - -__all__ = ["trainer", "LmTrainer"] - -LmTrainer.__module__ = "montreal_forced_aligner.lm" diff --git a/montreal_forced_aligner/lm/trainer.py b/montreal_forced_aligner/lm/trainer.py deleted file mode 100644 index daf64f22..00000000 --- a/montreal_forced_aligner/lm/trainer.py +++ /dev/null @@ -1,275 +0,0 @@ -"""Classes for training language models""" -from __future__ import annotations - -import logging -import os -import re -import subprocess -from typing import TYPE_CHECKING, Dict, Optional, Union - -from ..config import TEMP_DIR -from ..corpus import Corpus -from ..models import LanguageModel - -if TYPE_CHECKING: - from ..abc import Dictionary - from ..config.train_lm_config import TrainLMConfig - - -__all__ = ["LmTrainer"] - - -class LmTrainer: - """ - Train a language model from a corpus with text, or convert an existing ARPA-format language model to MFA format - - Parameters - ---------- - source: class:`~montreal_forced_aligner.corpus.Corpus` or str - Either a alignable corpus or a path to an ARPA format language model - config : class:`~montreal_forced_aligner.config.TrainLMConfig` - Config class for training language model - output_model_path : str - Path to output trained model - dictionary : class:`~montreal_forced_aligner.dictionary.PronunciationDictionary`, optional - Optional dictionary to calculate unknown words - temp_directory : str, optional - Specifies the temporary directory root to save files need for Kaldi. 
- If not specified, it will be set to ``~/Documents/MFA`` - supplemental_model_path: str, optional - Path to second language model to merge with the trained model - supplemental_model_weight : float, optional - Weight of supplemental model when merging, defaults to 1 - debug : bool - Flag for debug mode - logger : :class:`~logging.Logger`, optional - Logger to send output to - """ - - def __init__( - self, - source: Union[Corpus, str], - config: TrainLMConfig, - output_model_path: str, - dictionary: Optional[Dictionary] = None, - temp_directory: Optional[str] = None, - supplemental_model_path: Optional[str] = None, - supplemental_model_weight: int = 1, - debug: bool = False, - logger: Optional[logging.Logger] = None, - ): - if not temp_directory: - temp_directory = TEMP_DIR - temp_directory = os.path.join(temp_directory, "LM") - self.debug = debug - self.name, _ = os.path.splitext(os.path.basename(output_model_path)) - self.temp_directory = os.path.join(temp_directory, self.name) - self.models_temp_dir = os.path.join(self.temp_directory, "models") - self.log_directory = os.path.join(self.temp_directory, "logs") - self.log_file = os.path.join(self.log_directory, "train_lm.log") - os.makedirs(self.log_directory, exist_ok=True) - if logger is None: - self.logger = logging.getLogger("train_lm") - self.logger.setLevel(logging.INFO) - handler = logging.FileHandler(self.log_file, "w", "utf-8") - handler.setFormatter = logging.Formatter("%(name)s %(message)s") - self.logger.addHandler(handler) - else: - self.logger = logger - self.source = source - self.dictionary = dictionary - self.output_model_path = output_model_path - self.config = config - self.supplemental_model_path = supplemental_model_path - self.source_model_weight = 1 - self.supplemental_model_weight = supplemental_model_weight - - @property - def meta(self) -> Dict[str, Union[str, int, float]]: - """Metadata information for the language model""" - from ..utils import get_mfa_version - - return { - "type": 
"ngram", - "order": self.config.order, - "method": self.config.method, - "prune": self.config.prune, - "version": get_mfa_version(), - } - - def evaluate(self) -> None: - """ - Run an evaluation over the training data to generate perplexity score - """ - log_path = os.path.join(self.log_directory, "evaluate.log") - mod_path = os.path.join(self.temp_directory, self.name + ".mod") - far_path = os.path.join(self.temp_directory, self.name + ".far") - small_mod_path = mod_path.replace(".mod", "_small.mod") - med_mod_path = mod_path.replace(".mod", "_med.mod") - with open(log_path, "w", encoding="utf8") as log_file: - perplexity_proc = subprocess.Popen( - ["ngramperplexity", '--OOV_symbol=""', mod_path, far_path], - stdout=subprocess.PIPE, - stderr=log_file, - text=True, - ) - stdout, stderr = perplexity_proc.communicate() - num_sentences = None - num_words = None - num_oovs = None - perplexity = None - for line in stdout.splitlines(): - m = re.search(r"(\d+) sentences", line) - if m: - num_sentences = m.group(0) - m = re.search(r"(\d+) words", line) - if m: - num_words = m.group(0) - m = re.search(r"(\d+) OOVs", line) - if m: - num_oovs = m.group(0) - m = re.search(r"perplexity = ([\d.]+)", line) - if m: - perplexity = m.group(0) - - self.logger.info(f"{num_sentences} sentences, {num_words} words, {num_oovs} oovs") - self.logger.info(f"Perplexity of large model: {perplexity}") - - perplexity_proc = subprocess.Popen( - ["ngramperplexity", '--OOV_symbol=""', med_mod_path, far_path], - stdout=subprocess.PIPE, - stderr=log_file, - text=True, - ) - stdout, stderr = perplexity_proc.communicate() - - perplexity = None - for line in stdout.splitlines(): - m = re.search(r"perplexity = ([\d.]+)", line) - if m: - perplexity = m.group(0) - self.logger.info(f"Perplexity of medium model: {perplexity}") - perplexity_proc = subprocess.Popen( - ["ngramperplexity", '--OOV_symbol=""', small_mod_path, far_path], - stdout=subprocess.PIPE, - stderr=log_file, - text=True, - ) - stdout, stderr 
= perplexity_proc.communicate() - - perplexity = None - for line in stdout.splitlines(): - m = re.search(r"perplexity = ([\d.]+)", line) - if m: - perplexity = m.group(0) - self.logger.info(f"Perplexity of small model: {perplexity}") - - def train(self) -> None: - """ - Train a language model - """ - mod_path = os.path.join(self.temp_directory, f"{self.name}.mod") - large_model_path = os.path.join(self.temp_directory, f"{self.name}.arpa") - small_output_path = large_model_path.replace(".arpa", "_small.arpa") - med_output_path = large_model_path.replace(".arpa", "_med.arpa") - if isinstance(self.source, Corpus): - self.logger.info("Beginning training large ngram model...") - sym_path = os.path.join(self.temp_directory, f"{self.name}.sym") - far_path = os.path.join(self.temp_directory, f"{self.name}.far") - cnts_path = os.path.join(self.temp_directory, f"{self.name}.cnts") - training_path = os.path.join(self.temp_directory, "training.txt") - - with open(training_path, "w", encoding="utf8") as f: - for text in self.source.normalized_text_iter(self.config.count_threshold): - f.write(f"{text}\n") - - if self.dictionary is not None: - self.dictionary.save_oovs_found(self.temp_directory) - - subprocess.call(["ngramsymbols", '--OOV_symbol=""', training_path, sym_path]) - subprocess.call( - [ - "farcompilestrings", - "--fst_type=compact", - '--unknown_symbol=""', - "--symbols=" + sym_path, - "--keep_symbols", - training_path, - far_path, - ] - ) - subprocess.call(["ngramcount", f"--order={self.config.order}", far_path, cnts_path]) - subprocess.call(["ngrammake", f"--method={self.config.method}", cnts_path, mod_path]) - self.logger.info("Done!") - else: - self.logger.info("Parsing large ngram model...") - temp_text_path = os.path.join(self.temp_directory, "input.arpa") - with open(self.source, "r", encoding="utf8") as inf, open( - temp_text_path, "w", encoding="utf8" - ) as outf: - for line in inf: - outf.write(line.lower()) - subprocess.call(["ngramread", "--ARPA", 
temp_text_path, mod_path]) - os.remove(temp_text_path) - if self.supplemental_model_path: - self.logger.info("Parsing supplemental ngram model...") - supplemental_path = os.path.join(self.temp_directory, "extra.mod") - merged_path = os.path.join(self.temp_directory, "merged.mod") - subprocess.call( - ["ngramread", "--ARPA", self.supplemental_model_path, supplemental_path] - ) - self.logger.info("Merging both ngram models to create final large model...") - subprocess.call( - [ - "ngrammerge", - "--normalize", - f"--alpha={self.source_model_weight}", - f"--beta={self.supplemental_model_weight}", - mod_path, - supplemental_path, - merged_path, - ] - ) - mod_path = merged_path - - subprocess.call(["ngramprint", "--ARPA", mod_path, large_model_path]) - - self.logger.info("Large ngam model created!") - directory, filename = os.path.split(self.output_model_path) - basename, _ = os.path.splitext(filename) - - if self.config.prune: - self.logger.info("Pruning large ngram model to medium and small versions...") - small_mod_path = mod_path.replace(".mod", "_small.mod") - med_mod_path = mod_path.replace(".mod", "_med.mod") - subprocess.call( - [ - "ngramshrink", - "--method=relative_entropy", - f"--theta={self.config.prune_thresh_small}", - mod_path, - small_mod_path, - ] - ) - subprocess.call( - [ - "ngramshrink", - "--method=relative_entropy", - f"--theta={self.config.prune_thresh_medium}", - mod_path, - med_mod_path, - ] - ) - subprocess.call(["ngramprint", "--ARPA", small_mod_path, small_output_path]) - subprocess.call(["ngramprint", "--ARPA", med_mod_path, med_output_path]) - self.logger.info("Done!") - self.evaluate() - model = LanguageModel.empty(basename, root_directory=self.models_temp_dir) - model.add_meta_file(self) - model.add_arpa_file(large_model_path) - if self.config.prune: - model.add_arpa_file(med_output_path) - model.add_arpa_file(small_output_path) - basename, _ = os.path.splitext(self.output_model_path) - model.dump(basename) - # model.clean_up() diff 
--git a/montreal_forced_aligner/models.py b/montreal_forced_aligner/models.py index 7d1f0437..a6954e9b 100644 --- a/montreal_forced_aligner/models.py +++ b/montreal_forced_aligner/models.py @@ -6,26 +6,28 @@ from __future__ import annotations import os +import shutil from shutil import copy, copyfile, make_archive, move, rmtree, unpack_archive -from typing import TYPE_CHECKING, Collection, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Collection, Optional, Union import yaml -from .abc import Dictionary, MetaDict, MfaModel, Trainer -from .exceptions import ( +from montreal_forced_aligner.abc import MfaModel, ModelExporterMixin +from montreal_forced_aligner.exceptions import ( LanguageModelNotFoundError, ModelLoadError, PronunciationAcousticMismatchError, ) -from .helper import TerminalPrinter +from montreal_forced_aligner.helper import TerminalPrinter if TYPE_CHECKING: from logging import Logger - from .config import FeatureConfig - from .config.dictionary_config import DictionaryConfig - from .config.train_config import TrainingConfig - from .dictionary import PronunciationDictionary + from montreal_forced_aligner.abc import MetaDict + from montreal_forced_aligner.dictionary.pronunciation import ( + DictionaryMixin, + PronunciationDictionaryMixin, + ) # default format for output @@ -61,10 +63,10 @@ class Archive(MfaModel): extensions = [".zip"] def __init__(self, source: str, root_directory: Optional[str] = None): - from .config import TEMP_DIR + from .config import get_temporary_directory if root_directory is None: - root_directory = TEMP_DIR + root_directory = os.path.join(get_temporary_directory(), "extracted_models") self.root_directory = root_directory self.source = source self._meta = {} @@ -73,15 +75,45 @@ def __init__(self, source: str, root_directory: Optional[str] = None): self.dirname = os.path.abspath(source) else: self.dirname = os.path.join(root_directory, self.name) - if not os.path.exists(self.dirname): - 
os.makedirs(root_directory, exist_ok=True) - unpack_archive(source, self.dirname) - files = os.listdir(self.dirname) - old_dir_path = os.path.join(self.dirname, files[0]) - if len(files) == 1 and os.path.isdir(old_dir_path): # Backwards compatibility - for f in os.listdir(old_dir_path): - move(os.path.join(old_dir_path, f), os.path.join(self.dirname, f)) - os.rmdir(old_dir_path) + if os.path.exists(self.dirname): + shutil.rmtree(self.dirname, ignore_errors=True) + + os.makedirs(root_directory, exist_ok=True) + unpack_archive(source, self.dirname) + files = os.listdir(self.dirname) + old_dir_path = os.path.join(self.dirname, files[0]) + if len(files) == 1 and os.path.isdir(old_dir_path): # Backwards compatibility + for f in os.listdir(old_dir_path): + move(os.path.join(old_dir_path, f), os.path.join(self.dirname, f)) + os.rmdir(old_dir_path) + + def parse_old_features(self) -> None: + """ + Parse MFA model's features and ensure that they are up-to-date with current functionality + """ + if "features" not in self._meta: + return + feature_key_remapping = { + "type": "feature_type", + "deltas": "uses_deltas", + "lda": "uses_splices", + "fmllr": "uses_speaker_adaptation", + } + + for key, new_key in feature_key_remapping.items(): + if key in self._meta["features"]: + self._meta["features"][new_key] = self._meta["features"][key] + del self._meta["features"][key] + if "uses_splices" not in self._meta["features"]: # Backwards compatibility + self._meta["features"]["uses_splices"] = os.path.exists( + os.path.join(self.dirname, "lda.mat") + ) + if "multilingual_ipa" not in self._meta: + self._meta["multilingual_ipa"] = False + if "uses_speaker_adaptation" not in self._meta["features"]: + self._meta["features"]["uses_speaker_adaptation"] = os.path.exists( + os.path.join(self.dirname, "final.alimdl") + ) def get_subclass_object( self, @@ -91,8 +123,13 @@ def get_subclass_object( Returns ------- - Union[AcousticModel, G2PModel, LanguageModel, IvectorExtractor] + 
:class:`~montreal_forced_aligner.models.AcousticModel`, :class:`~montreal_forced_aligner.models.G2PModel`, :class:`~montreal_forced_aligner.models.LanguageModel`, or :class:`~montreal_forced_aligner.models.IvectorExtractorModel` Subclass model that was auto detected + + Raises + ------ + :class:`~montreal_forced_aligner.exceptions.ModelLoadError` + If the model type cannot be determined """ for f in os.listdir(self.dirname): if f == "tree": @@ -149,7 +186,7 @@ def generate_path(cls, root: str, name: str, enforce_existence: bool = True) -> return path return None - def pretty_print(self): + def pretty_print(self) -> None: """ Pretty print the archive's meta data using TerminalPrinter """ @@ -166,22 +203,25 @@ def meta(self) -> dict: meta_path = os.path.join(self.dirname, "meta.yaml") with open(meta_path, "r", encoding="utf8") as f: self._meta = yaml.safe_load(f) + self.parse_old_features() return self._meta - def add_meta_file(self, trainer: Trainer) -> None: + def add_meta_file(self, trainer: ModelExporterMixin) -> None: """ Add a metadata file from a given trainer to the model Parameters ---------- - trainer: Trainer + trainer: :class:`~montreal_forced_aligner.abc.ModelExporterMixin` The trainer to construct the metadata from """ with open(os.path.join(self.dirname, "meta.yaml"), "w", encoding="utf8") as f: yaml.dump(trainer.meta, f) @classmethod - def empty(cls, head: str, root_directory: Optional[str] = None) -> Archive: + def empty( + cls, head: str, root_directory: Optional[str] = None + ) -> Union[Archive, IvectorExtractorModel, AcousticModel, G2PModel, LanguageModel]: """ Initialize an archive using an empty directory @@ -194,13 +234,13 @@ def empty(cls, head: str, root_directory: Optional[str] = None) -> Archive: Returns ------- - Archive + :class:`~montreal_forced_aligner.models.Archive`, :class:`~montreal_forced_aligner.models.AcousticModel`, :class:`~montreal_forced_aligner.models.G2PModel`, :class:`~montreal_forced_aligner.models.LanguageModel`, or 
:class:`~montreal_forced_aligner.models.IvectorExtractorModel` Model constructed from the empty directory """ - from .config import TEMP_DIR + from .config import get_temporary_directory if root_directory is None: - root_directory = TEMP_DIR + root_directory = get_temporary_directory() os.makedirs(root_directory, exist_ok=True) source = os.path.join(root_directory, head) @@ -255,45 +295,34 @@ class AcousticModel(Archive): files = ["final.mdl", "final.alimdl", "final.occs", "lda.mat", "tree"] extensions = [".zip", ".am"] - def add_meta_file(self, trainer: Trainer) -> None: + model_type = "acoustic" + + def __init__(self, source: str, root_directory: Optional[str] = None): + if source in AcousticModel.get_available_models(): + source = AcousticModel.get_pretrained_path(source) + + super().__init__(source, root_directory) + + def add_meta_file(self, trainer: ModelExporterMixin) -> None: """ Add metadata file from a model trainer Parameters ---------- - trainer: :class:`~montreal_forced_aligner.abc.Trainer` + trainer: :class:`~montreal_forced_aligner.abc.ModelExporterMixin` Trainer to supply metadata information about the acoustic model """ with open(os.path.join(self.dirname, "meta.yaml"), "w", encoding="utf8") as f: yaml.dump(trainer.meta, f) @property - def feature_config(self) -> FeatureConfig: - """ - Return the FeatureConfig used in training the model - """ - from .config.feature_config import FeatureConfig - - fc = FeatureConfig() - fc.update(self.meta["features"]) - return fc - - def adaptation_config(self) -> Tuple[TrainingConfig, DictionaryConfig]: - """ - Generate an adaptation configuration - - Returns - ------- - TrainingConfig - Configuration to be used in adapting the acoustic model to new data - """ - from .config.train_config import load_no_sat_adapt, load_sat_adapt - - if self.meta["features"]["fmllr"]: - train, align, dictionary = load_sat_adapt() - else: - train, align, dictionary = load_no_sat_adapt() - return train, dictionary + def 
parameters(self) -> MetaDict: + """Parameters to pass to top-level workers""" + params = {**self.meta["features"]} + for key in ["multilingual_ipa"]: + params[key] = self.meta[key] + params["non_silence_phones"] = {x for x in self.meta["phones"]} + return params @property def meta(self) -> MetaDict: @@ -301,11 +330,25 @@ def meta(self) -> MetaDict: Metadata information for the acoustic model """ default_features = { - "type": "mfcc", + "feature_type": "mfcc", "use_energy": False, "frame_shift": 10, + "snip_edges": True, + "low_frequency": 20, + "high_frequency": 7800, + "sample_frequency": 16000, + "allow_downsample": True, + "allow_upsample": True, "pitch": False, - "fmllr": True, + "uses_cmvn": True, + "uses_deltas": True, + "uses_splices": False, + "uses_voiced": False, + "uses_speaker_adaptation": False, + "silence_weight": 0.0, + "fmllr_update_type": "full", + "splice_left_context": 3, + "splice_right_context": 3, } if not self._meta: meta_path = os.path.join(self.dirname, "meta.yaml") @@ -321,18 +364,10 @@ def meta(self) -> MetaDict: self._meta = yaml.safe_load(f) if self._meta["features"] == "mfcc+deltas": self._meta["features"] = default_features - if "uses_lda" not in self._meta: # Backwards compatibility - self._meta["uses_lda"] = os.path.exists(os.path.join(self.dirname, "lda.mat")) - if "multilingual_ipa" not in self._meta: - self._meta["multilingual_ipa"] = False - if "uses_sat" not in self._meta: - self._meta["uses_sat"] = False if "phone_type" not in self._meta: self._meta["phone_type"] = "triphone" self._meta["phones"] = set(self._meta.get("phones", [])) - self._meta["has_speaker_independent_model"] = os.path.exists( - os.path.join(self.dirname, "final.alimdl") - ) + self.parse_old_features() return self._meta def pretty_print(self) -> None: @@ -358,8 +393,10 @@ def pretty_print(self) -> None: configuration_data["Acoustic model"]["data"]["Architecture"] = self.meta["architecture"] configuration_data["Acoustic model"]["data"]["Phone type"] = 
self.meta["phone_type"] configuration_data["Acoustic model"]["data"]["Features"] = { - "Type": self.meta["features"]["type"], + "Feature type": self.meta["features"]["feature_type"], "Frame shift": self.meta["features"]["frame_shift"], + "Performs speaker adaptation": self.meta["features"]["uses_speaker_adaptation"], + "Performs LDA on features": self.meta["features"]["uses_splices"], } if self.meta["phones"]: configuration_data["Acoustic model"]["data"]["Phones"] = self.meta["phones"] @@ -368,9 +405,6 @@ def pretty_print(self) -> None: configuration_data["Acoustic model"]["data"]["Configuration options"] = { "Multilingual IPA": self.meta["multilingual_ipa"], - "Performs speaker adaptation": self.meta["uses_sat"], - "Has speaker-independent model": self.meta["has_speaker_independent_model"], - "Performs LDA on features": self.meta["uses_lda"], } printer.print_config(configuration_data) @@ -423,25 +457,25 @@ def log_details(self, logger: Logger) -> None: logger.debug(stream) logger.debug("") - def validate(self, dictionary: Union[Dictionary, G2PModel]) -> None: + def validate(self, dictionary: DictionaryMixin) -> None: """ Validate this acoustic model against a pronunciation dictionary or G2P model to ensure their phone sets are compatible Parameters ---------- - dictionary: Union[DictionaryConfig, G2PModel] - PronunciationDictionary or G2P model to compare phone sets with + dictionary: Union[:class:`~montreal_forced_aligner.dictionary.pronunciation.PronunciationDictionaryMixin`, :class:`~montreal_forced_aligner.models.G2PModel`] + PronunciationDictionaryMixin or G2P model to compare phone sets with Raises ------ - PronunciationAcousticMismatchError + :class:`~montreal_forced_aligner.exceptions.PronunciationAcousticMismatchError` If there are phones missing from the acoustic model """ if isinstance(dictionary, G2PModel): missing_phones = dictionary.meta["phones"] - set(self.meta["phones"]) else: - missing_phones = dictionary.config.non_silence_phones - 
set(self.meta["phones"]) + missing_phones = dictionary.non_silence_phones - set(self.meta["phones"]) if missing_phones: raise (PronunciationAcousticMismatchError(missing_phones)) @@ -451,6 +485,8 @@ class IvectorExtractorModel(Archive): Model class for IvectorExtractor """ + model_type = "ivector" + model_files = [ "final.ie", "final.ubm", @@ -461,6 +497,20 @@ class IvectorExtractorModel(Archive): ] extensions = [".zip", ".ivector"] + def __init__(self, source: str, root_directory: Optional[str] = None): + if source in IvectorExtractorModel.get_available_models(): + source = IvectorExtractorModel.get_pretrained_path(source) + + super().__init__(source, root_directory) + + @property + def parameters(self) -> MetaDict: + """Parameters to pass to top-level workers""" + params = {**self.meta["features"]} + for key in ["ivector_dimension", "num_gselect", "min_post", "posterior_scale"]: + params[key] = self.meta[key] + return params + def add_model(self, source: str) -> None: """ Add file into archive @@ -488,31 +538,28 @@ def export_model(self, destination: str) -> None: if os.path.exists(os.path.join(self.dirname, filename)): copyfile(os.path.join(self.dirname, filename), os.path.join(destination, filename)) - @property - def feature_config(self) -> FeatureConfig: - """ - Return the FeatureConfig used in training the model - """ - from .config.feature_config import FeatureConfig - - fc = FeatureConfig() - fc.update(self.meta["features"]) - return fc - class G2PModel(Archive): extensions = [".zip", ".g2p"] + model_type = "g2p" + + def __init__(self, source: str, root_directory: Optional[str] = None): + if source in G2PModel.get_available_models(): + source = G2PModel.get_pretrained_path(source) + + super().__init__(source, root_directory) + def add_meta_file( - self, dictionary: PronunciationDictionary, architecture: Optional[str] = None + self, dictionary: PronunciationDictionaryMixin, architecture: Optional[str] = None ) -> None: """ Construct meta data information 
for the G2P model from the dictionary it was trained from Parameters ---------- - dictionary: PronunciationDictionary - PronunciationDictionary that was the training data for the G2P model + dictionary: :class:`~montreal_forced_aligner.dictionary.pronunciation.PronunciationDictionaryMixin` + Pronunciation dictionary that was the training data for the G2P model architecture: str, optional Architecture of the G2P model, defaults to "pynini" """ @@ -522,7 +569,7 @@ def add_meta_file( architecture = "pynini" with open(os.path.join(self.dirname, "meta.yaml"), "w", encoding="utf8") as f: meta = { - "phones": sorted(dictionary.config.non_silence_phones), + "phones": sorted(dictionary.non_silence_phones), "graphemes": sorted(dictionary.graphemes), "architecture": architecture, "version": get_mfa_version(), @@ -559,7 +606,7 @@ def add_sym_path(self, source_directory: str) -> None: Parameters ---------- - source: str + source_directory: str Source directory path """ if not os.path.exists(self.sym_path): @@ -571,7 +618,7 @@ def add_fst_model(self, source_directory: str) -> None: Parameters ---------- - source: str + source_directory: str Source directory path """ if not os.path.exists(self.fst_path): @@ -622,30 +669,29 @@ class LanguageModel(Archive): Class for MFA language models """ + model_type = "language_model" + arpa_extension = ".arpa" extensions = [f".{FORMAT}", arpa_extension, ".lm"] def __init__(self, source: str, root_directory: Optional[str] = None): - from .config import TEMP_DIR + if source in LanguageModel.get_available_models(): + source = LanguageModel.get_pretrained_path(source) + from .config import get_temporary_directory if root_directory is None: - root_directory = TEMP_DIR - self.root_directory = root_directory - self._meta = {} - self.name, _ = os.path.splitext(os.path.basename(source)) - if os.path.isdir(source): - self.dirname = os.path.abspath(source) - elif source.endswith(self.arpa_extension): + root_directory = get_temporary_directory() + + if 
source.endswith(self.arpa_extension): + self.root_directory = root_directory + self._meta = {} + self.name, _ = os.path.splitext(os.path.basename(source)) self.dirname = os.path.join(root_directory, self.name) if not os.path.exists(self.dirname): os.makedirs(self.dirname, exist_ok=True) copy(source, self.large_arpa_path) - elif any(source.endswith(x) for x in self.extensions): - base = root_directory - self.dirname = os.path.join(root_directory, self.name) - if not os.path.exists(self.dirname): - os.makedirs(root_directory, exist_ok=True) - unpack_archive(source, base) + else: + super().__init__(source, root_directory) @property def decode_arpa_path(self) -> str: @@ -670,12 +716,12 @@ def carpa_path(self) -> str: @property def small_arpa_path(self) -> str: """Small arpa path""" - return os.path.join(self.dirname, self.name + "_small" + self.arpa_extension) + return os.path.join(self.dirname, f"{self.name}_small{self.arpa_extension}") @property def medium_arpa_path(self) -> str: """Medium arpa path""" - return os.path.join(self.dirname, self.name + "_med" + self.arpa_extension) + return os.path.join(self.dirname, f"{self.name}_med{self.arpa_extension}") @property def large_arpa_path(self) -> str: @@ -691,8 +737,12 @@ def add_arpa_file(self, arpa_path: str) -> None: arpa_path: str Path to ARPA file """ - name = os.path.basename(arpa_path) - copyfile(arpa_path, os.path.join(self.dirname, name)) + output_name = self.large_arpa_path + if arpa_path.endswith("_small.arpa"): + output_name = self.small_arpa_path + elif arpa_path.endswith("_medium.arpa"): + output_name = self.medium_arpa_path + copyfile(arpa_path, output_name) class DictionaryModel(MfaModel): @@ -700,9 +750,13 @@ class DictionaryModel(MfaModel): Class for representing MFA pronunciation dictionaries """ + model_type = "dictionary" + extensions = [".dict", ".txt", ".yaml", ".yml"] def __init__(self, path: str): + if path in DictionaryModel.get_available_models(): + path = 
DictionaryModel.get_pretrained_path(path) self.path = path count = 0 self.pronunciation_probabilities = True @@ -739,12 +793,14 @@ def __init__(self, path: str): @property def meta(self) -> MetaDict: + """Metadata for the dictionary""" return { "pronunciation_probabilities": self.pronunciation_probabilities, "silence_probabilities": self.silence_probabilities, } - def add_meta_file(self, trainer: Trainer) -> None: + def add_meta_file(self, trainer: ModelExporterMixin) -> None: + """Not implemented method""" raise NotImplementedError def pretty_print(self): @@ -800,24 +856,29 @@ def generate_path(cls, root: str, name: str, enforce_existence: bool = True) -> return None @property - def is_multiple(self): + def is_multiple(self) -> bool: + """Flag for whether the dictionary contains multiple lexicons""" return os.path.splitext(self.path)[1] in [".yaml", ".yml"] @property - def name(self): + def name(self) -> str: + """Name of the dictionary""" return os.path.splitext(os.path.basename(self.path))[0] - def load_dictionary_paths(self) -> Dict[str, DictionaryModel]: - from .utils import get_available_dictionaries, get_dictionary_path + def load_dictionary_paths(self) -> dict[str, DictionaryModel]: + """ + Load the pronunciation dictionaries + Returns + ------- + dict[str, :class:`~montreal_forced_aligner.models.DictionaryModel`] + Mapping of component pronunciation dictionaries + """ mapping = {} if self.is_multiple: - available_langs = get_available_dictionaries() with open(self.path, "r", encoding="utf8") as f: data = yaml.safe_load(f) for speaker, path in data.items(): - if path in available_langs: - path = get_dictionary_path(path) mapping[speaker] = DictionaryModel(path) else: mapping["default"] = self diff --git a/montreal_forced_aligner/multiprocessing/__init__.py b/montreal_forced_aligner/multiprocessing/__init__.py deleted file mode 100644 index a6ef7a77..00000000 --- a/montreal_forced_aligner/multiprocessing/__init__.py +++ /dev/null @@ -1,64 +0,0 @@ -""" 
-Multiprocessing functions -========================= - -""" -from .alignment import calc_fmllr # noqa -from .alignment import calc_lda_mllt # noqa -from .alignment import compile_information # noqa -from .alignment import compile_train_graphs # noqa -from .alignment import compute_alignment_improvement # noqa -from .alignment import convert_ali_to_textgrids # noqa -from .alignment import convert_alignments # noqa -from .alignment import create_align_model # noqa -from .alignment import lda_acc_stats # noqa -from .alignment import mono_align_equal # noqa -from .alignment import train_map # noqa -from .alignment import tree_stats # noqa -from .alignment import ( # noqa - CleanupWordCtmProcessWorker, - CombineProcessWorker, - ExportPreparationProcessWorker, - ExportTextGridProcessWorker, - NoCleanupWordCtmProcessWorker, - PhoneCtmProcessWorker, - acc_stats, - acc_stats_func, - align, - align_func, -) -from .classes import Job # noqa -from .corpus import CorpusProcessWorker # noqa -from .helper import Counter, ProcessWorker, Stopped, run_mp, run_non_mp # noqa -from .ivector import acc_global_stats # noqa -from .ivector import acc_ivector_stats # noqa -from .ivector import extract_ivectors # noqa -from .ivector import gauss_to_post # noqa -from .ivector import gmm_gselect # noqa -from .ivector import segment_vad # noqa -from .pronunciations import generate_pronunciations # noqa -from .transcription import transcribe, transcribe_fmllr # noqa - -__all__ = [ - "alignment", - "classes", - "corpus", - "features", - "helper", - "ivector", - "pronunciations", - "transcription", -] - -Job.__module__ = "montreal_forced_aligner.multiprocessing" -CleanupWordCtmProcessWorker.__module__ = "montreal_forced_aligner.multiprocessing" -CombineProcessWorker.__module__ = "montreal_forced_aligner.multiprocessing" -PhoneCtmProcessWorker.__module__ = "montreal_forced_aligner.multiprocessing" -ExportPreparationProcessWorker.__module__ = "montreal_forced_aligner.multiprocessing" 
-ExportTextGridProcessWorker.__module__ = "montreal_forced_aligner.multiprocessing" -NoCleanupWordCtmProcessWorker.__module__ = "montreal_forced_aligner.multiprocessing" - -CorpusProcessWorker.__module__ = "montreal_forced_aligner.multiprocessing" -Counter.__module__ = "montreal_forced_aligner.multiprocessing" -Stopped.__module__ = "montreal_forced_aligner.multiprocessing" -ProcessWorker.__module__ = "montreal_forced_aligner.multiprocessing" diff --git a/montreal_forced_aligner/multiprocessing/alignment.py b/montreal_forced_aligner/multiprocessing/alignment.py deleted file mode 100644 index 57ad5c92..00000000 --- a/montreal_forced_aligner/multiprocessing/alignment.py +++ /dev/null @@ -1,2731 +0,0 @@ -""" -Aligment functions ------------------- - -""" -from __future__ import annotations - -import multiprocessing as mp -import os -import re -import statistics -import subprocess -import sys -import time -import traceback -from queue import Empty -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union - -from ..exceptions import AlignmentError, AlignmentExportError -from ..multiprocessing.helper import Stopped -from ..textgrid import ( - ctms_to_textgrids_non_mp, - export_textgrid, - generate_tiers, - output_textgrid_writing_errors, - parse_from_phone, - parse_from_word, - parse_from_word_no_cleanup, - process_ctm_line, -) -from ..utils import thirdparty_binary -from .helper import run_mp, run_non_mp - -if TYPE_CHECKING: - from ..abc import Aligner, CtmErrorDict, MetaDict, Trainer - from ..aligner.adapting import AdaptingAligner - from ..aligner.base import BaseAligner - from ..config.align_config import AlignConfig - from ..corpus.classes import ( - CleanupWordCtmArguments, - CombineCtmArguments, - ExportTextGridArguments, - File, - NoCleanupWordCtmArguments, - PhoneCtmArguments, - Utterance, - ) - from ..data import CtmType - from ..trainers import BaseTrainer, LdaTrainer, MonophoneTrainer, SatTrainer - - ConfigType = Union[BaseTrainer, AlignConfig] - - 
-queue_polling_timeout = 1 - -__all__ = [ - "acc_stats", - "align", - "mono_align_equal", - "tree_stats", - "compile_train_graphs", - "compile_information", - "convert_alignments", - "convert_ali_to_textgrids", - "compute_alignment_improvement", - "compare_alignments", - "PhoneCtmProcessWorker", - "CleanupWordCtmProcessWorker", - "NoCleanupWordCtmProcessWorker", - "CombineProcessWorker", - "ExportPreparationProcessWorker", - "ExportTextGridProcessWorker", - "calc_fmllr", - "calc_lda_mllt", - "create_align_model", - "ctms_to_textgrids_mp", - "lda_acc_stats", - "train_map", - "parse_iteration_alignments", - "convert_alignments_func", - "align_func", - "ali_to_ctm_func", - "compute_alignment_improvement_func", - "mono_align_equal_func", - "calc_fmllr_func", - "calc_lda_mllt_func", - "lda_acc_stats_func", - "tree_stats_func", - "map_acc_stats_func", - "acc_stats_two_feats_func", - "compile_information_func", - "compile_train_graphs_func", - "compile_utterance_train_graphs_func", - "test_utterances_func", - "acc_stats_func", -] - - -def acc_stats_func( - log_path: str, - dictionaries: List[str], - feature_strings: Dict[str, str], - ali_paths: Dict[str, str], - acc_paths: Dict[str, str], - model_path: str, -) -> None: - """ - Multiprocessing function for accumulating stats in GMM training - - Parameters - ---------- - log_path: str - Path to save log output - dictionaries: List[str] - List of dictionary names - feature_strings: Dict[str, str] - Dictionary of feature strings per dictionary name - ali_paths: Dict[str, str] - Dictionary of alignment archives per dictionary name - acc_paths: Dict[str, str] - Dictionary of accumulated stats files per dictionary name - model_path: str - Path to the acoustic model file - """ - model_path = model_path - with open(log_path, "w", encoding="utf8") as log_file: - for dict_name in dictionaries: - acc_proc = subprocess.Popen( - [ - thirdparty_binary("gmm-acc-stats-ali"), - model_path, - feature_strings[dict_name], - 
f"ark,s,cs:{ali_paths[dict_name]}", - acc_paths[dict_name], - ], - stderr=log_file, - env=os.environ, - ) - acc_proc.communicate() - - -def acc_stats(aligner: Trainer): - """ - Multiprocessing function that accumulates stats for GMM training - - Parameters - ---------- - aligner : Trainer - Trainer - - Notes - ----- - See :kaldi_src:`gmmbin/gmm-acc-stats-ali` for more details on the Kaldi - binary, and :kaldi_steps:`train_mono` for an example Kaldi script - """ - arguments = [j.acc_stats_arguments(aligner) for j in aligner.corpus.jobs] - - if aligner.use_mp: - run_mp(acc_stats_func, arguments, aligner.working_log_directory) - else: - run_non_mp(acc_stats_func, arguments, aligner.working_log_directory) - - log_path = os.path.join(aligner.working_log_directory, f"update.{aligner.iteration}.log") - with open(log_path, "w") as log_file: - acc_files = [] - for a in arguments: - acc_files.extend(a.acc_paths.values()) - sum_proc = subprocess.Popen( - [thirdparty_binary("gmm-sum-accs"), "-"] + acc_files, - stdout=subprocess.PIPE, - stderr=log_file, - env=os.environ, - ) - est_proc = subprocess.Popen( - [ - thirdparty_binary("gmm-est"), - f"--write-occs={aligner.next_occs_path}", - f"--mix-up={aligner.current_gaussians}", - f"--power={aligner.power}", - aligner.current_model_path, - "-", - aligner.next_model_path, - ], - stdin=sum_proc.stdout, - stderr=log_file, - env=os.environ, - ) - est_proc.communicate() - avg_like_pattern = re.compile( - r"Overall avg like per frame \(Gaussian only\) = (?P[-.,\d]+) over (?P[.\d+e]) frames" - ) - average_logdet_pattern = re.compile( - r"Overall average logdet is (?P[-.,\d]+) over (?P[.\d+e]) frames" - ) - avg_like_sum = 0 - avg_like_frames = 0 - average_logdet_sum = 0 - average_logdet_frames = 0 - for a in arguments: - with open(a.log_path, "r", encoding="utf8") as f: - for line in f: - m = re.search(avg_like_pattern, line) - if m: - like = float(m.group("like")) - frames = float(m.group("frames")) - avg_like_sum += like * frames - 
avg_like_frames += frames - m = re.search(average_logdet_pattern, line) - if m: - logdet = float(m.group("logdet")) - frames = float(m.group("frames")) - average_logdet_sum += logdet * frames - average_logdet_frames += frames - if avg_like_frames: - log_like = avg_like_sum / avg_like_frames - if average_logdet_frames: - log_like += average_logdet_sum / average_logdet_frames - aligner.logger.debug(f"Likelihood for iteration {aligner.iteration}: {log_like}") - - if not aligner.debug: - for f in acc_files: - os.remove(f) - - -def compile_train_graphs_func( - log_path: str, - dictionaries: List[str], - tree_path: str, - model_path: str, - text_int_paths: Dict[str, str], - disambig_paths: Dict[str, str], - lexicon_fst_paths: Dict[str, str], - fst_scp_paths: Dict[str, str], -) -> None: - """ - Multiprocessing function to compile training graphs - - Parameters - ---------- - log_path: str - Path to save log output - dictionaries: List[str] - List of dictionary names - tree_path: str - Path to the acoustic model tree file - model_path: str - Path to the acoustic model file - text_int_paths: Dict[str, str] - PronunciationDictionary of text int files per dictionary name - disambig_paths: Dict[str, str] - PronunciationDictionary of disambiguation symbol int files per dictionary name - lexicon_fst_paths: Dict[str, str] - PronunciationDictionary of L.fst files per dictionary name - fst_scp_paths: Dict[str, str] - PronunciationDictionary of utterance FST scp files per dictionary name - """ - with open(log_path, "w", encoding="utf8") as log_file: - for dict_name in dictionaries: - disambig_path = disambig_paths[dict_name] - fst_scp_path = fst_scp_paths[dict_name] - fst_ark_path = fst_scp_path.replace(".scp", ".ark") - text_path = text_int_paths[dict_name] - proc = subprocess.Popen( - [ - thirdparty_binary("compile-train-graphs"), - f"--read-disambig-syms={disambig_path}", - tree_path, - model_path, - lexicon_fst_paths[dict_name], - f"ark:{text_path}", - 
f"ark,scp:{fst_ark_path},{fst_scp_path}", - ], - stderr=log_file, - env=os.environ, - ) - proc.communicate() - - -def compile_train_graphs(aligner: Union[BaseAligner, BaseTrainer]) -> None: - """ - Multiprocessing function that compiles training graphs for utterances - - See http://kaldi-asr.org/doc/compile-train-graphs_8cc.html for more details - on the Kaldi binary this function calls. - - Also see https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/train_mono.sh - for the bash script that this function was extracted from. - - Parameters - ---------- - aligner: Aligner - Aligner - """ - aligner.logger.debug("Compiling training graphs...") - begin = time.time() - log_directory = aligner.working_log_directory - os.makedirs(log_directory, exist_ok=True) - jobs = [x.compile_train_graph_arguments(aligner) for x in aligner.corpus.jobs] - if aligner.use_mp: - run_mp(compile_train_graphs_func, jobs, log_directory) - else: - run_non_mp(compile_train_graphs_func, jobs, log_directory) - aligner.logger.debug(f"Compiling training graphs took {time.time() - begin}") - - -def mono_align_equal_func( - log_path: str, - dictionaries: List[str], - feature_strings: Dict[str, str], - fst_scp_paths: Dict[str, str], - ali_ark_paths: Dict[str, str], - acc_paths: Dict[str, str], - model_path: str, -): - """ - Multiprocessing function for initializing monophone alignments - - Parameters - ---------- - log_path: str - Path to save log output - dictionaries: List[str] - List of dictionary names - feature_strings: Dict[str, str] - PronunciationDictionary of feature strings per dictionary name - fst_scp_paths: Dict[str, str] - PronunciationDictionary of utterance FST scp files per dictionary name - ali_ark_paths: Dict[str, str] - PronunciationDictionary of alignment archives per dictionary name - acc_paths: Dict[str, str] - PronunciationDictionary of accumulated stats files per dictionary name - model_path: str - Path to the acoustic model file - """ - with open(log_path, "w", 
encoding="utf8") as log_file: - for dict_name in dictionaries: - fst_path = fst_scp_paths[dict_name] - ali_path = ali_ark_paths[dict_name] - acc_path = acc_paths[dict_name] - align_proc = subprocess.Popen( - [ - thirdparty_binary("align-equal-compiled"), - f"scp:{fst_path}", - feature_strings[dict_name], - f"ark:{ali_path}", - ], - stderr=log_file, - env=os.environ, - ) - align_proc.communicate() - stats_proc = subprocess.Popen( - [ - thirdparty_binary("gmm-acc-stats-ali"), - "--binary=true", - model_path, - feature_strings[dict_name], - f"ark:{ali_path}", - acc_path, - ], - stdin=align_proc.stdout, - stderr=log_file, - env=os.environ, - ) - stats_proc.communicate() - - -def mono_align_equal(aligner: MonophoneTrainer): - """ - Multiprocessing function that creates equal alignments for base monophone training - - See http://kaldi-asr.org/doc/align-equal-compiled_8cc.html for more details - on the Kaldi binary this function calls. - - Also see https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/train_mono.sh - for the bash script that this function was extracted from. 
- - Parameters - ---------- - aligner: :class:`~montreal_forced_aligner.trainers.MonophoneTrainer` - Monophone trainer - """ - - arguments = [x.mono_align_equal_arguments(aligner) for x in aligner.corpus.jobs] - - if aligner.use_mp: - run_mp(mono_align_equal_func, arguments, aligner.log_directory) - else: - run_non_mp(mono_align_equal_func, arguments, aligner.log_directory) - - log_path = os.path.join(aligner.working_log_directory, "update.0.log") - with open(log_path, "w") as log_file: - acc_files = [] - for x in arguments: - acc_files.extend(sorted(x.acc_paths.values())) - sum_proc = subprocess.Popen( - [thirdparty_binary("gmm-sum-accs"), "-"] + acc_files, - stderr=log_file, - stdout=subprocess.PIPE, - env=os.environ, - ) - est_proc = subprocess.Popen( - [ - thirdparty_binary("gmm-est"), - "--min-gaussian-occupancy=3", - f"--mix-up={aligner.current_gaussians}", - f"--power={aligner.power}", - aligner.current_model_path, - "-", - aligner.next_model_path, - ], - stderr=log_file, - stdin=sum_proc.stdout, - env=os.environ, - ) - est_proc.communicate() - if not aligner.debug: - for f in acc_files: - os.remove(f) - - -def align_func( - log_path: str, - dictionaries: List[str], - fst_scp_paths: Dict[str, str], - feature_strings: Dict[str, str], - model_path: str, - ali_paths: Dict[str, str], - score_paths: Dict[str, str], - loglike_paths: Dict[str, str], - align_options: MetaDict, -): - """ - Multiprocessing function for alignment - - Parameters - ---------- - log_path: str - Path to save log output - dictionaries: List[str] - List of dictionary names - fst_scp_paths: Dict[str, str] - PronunciationDictionary of FST scp file paths per dictionary name - feature_strings: Dict[str, str] - PronunciationDictionary of feature strings per dictionary name - model_path: str - Path to the acoustic model file - ali_paths: Dict[str, str] - PronunciationDictionary of alignment archives per dictionary name - score_paths: Dict[str, str] - PronunciationDictionary of scores files per 
dictionary name - loglike_paths: Dict[str, str] - PronunciationDictionary of log likelihood files per dictionary name - align_options: :class:`~montreal_forced_aligner.abc.MetaDict` - Options for alignment - """ - with open(log_path, "w", encoding="utf8") as log_file: - for dict_name in dictionaries: - feature_string = feature_strings[dict_name] - fst_path = fst_scp_paths[dict_name] - ali_path = ali_paths[dict_name] - com = [ - thirdparty_binary("gmm-align-compiled"), - f"--transition-scale={align_options['transition_scale']}", - f"--acoustic-scale={align_options['acoustic_scale']}", - f"--self-loop-scale={align_options['self_loop_scale']}", - f"--beam={align_options['beam']}", - f"--retry-beam={align_options['retry_beam']}", - "--careful=false", - "-", - f"scp:{fst_path}", - feature_string, - f"ark:{ali_path}", - ] - if align_options["debug"]: - loglike_path = loglike_paths[dict_name] - score_path = score_paths[dict_name] - com.insert(1, f"--write-per-frame-acoustic-loglikes=ark,t:{loglike_path}") - com.append(f"ark,t:{score_path}") - - boost_proc = subprocess.Popen( - [ - thirdparty_binary("gmm-boost-silence"), - f"--boost={align_options['boost_silence']}", - align_options["optional_silence_csl"], - model_path, - "-", - ], - stderr=log_file, - stdout=subprocess.PIPE, - env=os.environ, - ) - align_proc = subprocess.Popen( - com, stderr=log_file, stdin=boost_proc.stdout, env=os.environ - ) - align_proc.communicate() - - -def align(aligner: Union[BaseAligner, BaseTrainer]) -> None: - """ - Multiprocessing function that aligns based on the current model - - See http://kaldi-asr.org/doc/gmm-align-compiled_8cc.html and - http://kaldi-asr.org/doc/gmm-boost-silence_8cc.html for more details - on the Kaldi binary this function calls. - - Also see https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/align_si.sh - for the bash script this function was based on. 
- - Parameters - ---------- - aligner: Aligner - Aligner - """ - begin = time.time() - log_directory = aligner.working_log_directory - - arguments = [x.align_arguments(aligner) for x in aligner.corpus.jobs] - if aligner.use_mp: - run_mp(align_func, arguments, log_directory) - else: - run_non_mp(align_func, arguments, log_directory) - - error_logs = [] - for j in arguments: - - with open(j.log_path, "r", encoding="utf8") as f: - for line in f: - if line.strip().startswith("ERROR"): - error_logs.append(j.log_path) - break - if error_logs: - raise AlignmentError(error_logs) - aligner.logger.debug(f"Alignment round took {time.time() - begin}") - - -def compile_information_func(align_log_path: str) -> Dict[str, Union[List[str], float, int]]: - """ - Multiprocessing function for compiling information about alignment - - Parameters - ---------- - align_log_path: str - Log path for alignment - - Returns - ------- - Dict - Information about log-likelihood and number of unaligned files - """ - average_logdet_pattern = re.compile( - r"Overall average logdet is (?P[-.,\d]+) over (?P[.\d+e]+) frames" - ) - log_like_pattern = re.compile( - r"^LOG .* Overall log-likelihood per frame is (?P[-0-9.]+) over (?P\d+) frames.*$" - ) - - decode_error_pattern = re.compile( - r"^WARNING .* Did not successfully decode file (?P.*?), .*$" - ) - - data = {"unaligned": [], "too_short": [], "log_like": 0, "total_frames": 0} - with open(align_log_path, "r", encoding="utf8") as f: - for line in f: - decode_error_match = re.match(decode_error_pattern, line) - if decode_error_match: - data["unaligned"].append(decode_error_match.group("utt")) - continue - log_like_match = re.match(log_like_pattern, line) - if log_like_match: - log_like = log_like_match.group("log_like") - frames = log_like_match.group("frames") - data["log_like"] = float(log_like) - data["total_frames"] = int(frames) - m = re.search(average_logdet_pattern, line) - if m: - logdet = float(m.group("logdet")) - frames = 
float(m.group("frames")) - data["logdet"] = logdet - data["logdet_frames"] = frames - return data - - -def compile_information(aligner: Union[BaseAligner, BaseTrainer]) -> Tuple[Dict[str, str], float]: - """ - Compiles information about alignment, namely what the overall log-likelihood was - and how many files were unaligned - - Parameters - ---------- - aligner: Aligner - Aligner - - Returns - ------- - Dict - Unaligned files - float - Log-likelihood of alignment - """ - compile_info_begin = time.time() - - jobs = [x.compile_information_arguments(aligner) for x in aligner.corpus.jobs] - - if aligner.use_mp: - alignment_info = run_mp( - compile_information_func, jobs, aligner.working_log_directory, True - ) - else: - alignment_info = run_non_mp( - compile_information_func, jobs, aligner.working_log_directory, True - ) - - unaligned = {} - avg_like_sum = 0 - avg_like_frames = 0 - average_logdet_sum = 0 - average_logdet_frames = 0 - for data in alignment_info.values(): - avg_like_frames += data["total_frames"] - avg_like_sum += data["log_like"] * data["total_frames"] - if "logdet_frames" in data: - average_logdet_frames += data["logdet_frames"] - average_logdet_sum += data["logdet"] * data["logdet_frames"] - for u in data["unaligned"]: - unaligned[u] = "Beam too narrow" - for u in data["too_short"]: - unaligned[u] = "Segment too short" - - if not avg_like_frames: - aligner.logger.warning( - "No files were aligned, this likely indicates serious problems with the aligner." 
- ) - aligner.logger.debug(f"Compiling information took {time.time() - compile_info_begin}") - log_like = avg_like_sum / avg_like_frames - if average_logdet_sum: - log_like += average_logdet_sum / average_logdet_frames - return unaligned, log_like - - -def compute_alignment_improvement_func( - log_path: str, - dictionaries: List[str], - model_path: str, - text_int_paths: Dict[str, str], - word_boundary_paths: Dict[str, str], - ali_paths: Dict[str, str], - frame_shift: int, - reversed_phone_mappings: Dict[str, Dict[int, str]], - positions: Dict[str, List[str]], - phone_ctm_paths: Dict[str, str], -) -> None: - """ - Multiprocessing function for computing alignment improvement over training - - Parameters - ---------- - log_path: str - Path to save log output - dictionaries: List[str] - List of dictionary names - model_path: str - Path to the acoustic model file - text_int_paths: Dict[str, str] - PronunciationDictionary of text int files per dictionary name - word_boundary_paths: Dict[str, str] - PronunciationDictionary of word boundary files per dictionary name - ali_paths: Dict[str, str] - PronunciationDictionary of alignment archives per dictionary name - frame_shift: int - Frame shift of feature generation, in ms - reversed_phone_mappings: Dict[str, Dict[int, str]] - Mapping of phone IDs to phone labels per dictionary name - positions: Dict[str, List[str]] - Positions per dictionary name - phone_ctm_paths: Dict[str, str] - PronunciationDictionary of phone ctm files per dictionary name - """ - try: - - frame_shift = frame_shift / 1000 - with open(log_path, "w", encoding="utf8") as log_file: - for dict_name in dictionaries: - text_int_path = text_int_paths[dict_name] - ali_path = ali_paths[dict_name] - phone_ctm_path = phone_ctm_paths[dict_name] - word_boundary_path = word_boundary_paths[dict_name] - if os.path.exists(phone_ctm_path): - continue - - lin_proc = subprocess.Popen( - [ - thirdparty_binary("linear-to-nbest"), - f"ark:{ali_path}", - 
f"ark:{text_int_path}", - "", - "", - "ark:-", - ], - stdout=subprocess.PIPE, - stderr=log_file, - env=os.environ, - ) - det_proc = subprocess.Popen( - [thirdparty_binary("lattice-determinize-pruned"), "ark:-", "ark:-"], - stdin=lin_proc.stdout, - stderr=log_file, - stdout=subprocess.PIPE, - env=os.environ, - ) - align_proc = subprocess.Popen( - [ - thirdparty_binary("lattice-align-words"), - word_boundary_path, - model_path, - "ark:-", - "ark:-", - ], - stdin=det_proc.stdout, - stderr=log_file, - stdout=subprocess.PIPE, - env=os.environ, - ) - phone_proc = subprocess.Popen( - [thirdparty_binary("lattice-to-phone-lattice"), model_path, "ark:-", "ark:-"], - stdin=align_proc.stdout, - stdout=subprocess.PIPE, - stderr=log_file, - env=os.environ, - ) - nbest_proc = subprocess.Popen( - [ - thirdparty_binary("nbest-to-ctm"), - f"--frame-shift={frame_shift}", - "ark:-", - phone_ctm_path, - ], - stdin=phone_proc.stdout, - stderr=log_file, - env=os.environ, - ) - nbest_proc.communicate() - mapping = reversed_phone_mappings[dict_name] - actual_lines = [] - with open(phone_ctm_path, "r", encoding="utf8") as f: - for line in f: - line = line.strip() - if line == "": - continue - line = line.split(" ") - utt = line[0] - begin = float(line[2]) - duration = float(line[3]) - end = begin + duration - label = line[4] - try: - label = mapping[int(label)] - except KeyError: - pass - for p in positions[dict_name]: - if label.endswith(p): - label = label[: -1 * len(p)] - actual_lines.append([utt, begin, end, label]) - with open(phone_ctm_path, "w", encoding="utf8") as f: - for line in actual_lines: - f.write(f"{' '.join(map(str, line))}\n") - except Exception as e: - raise (Exception(str(e))) - - -def parse_iteration_alignments( - aligner: Trainer, iteration: Optional[int] = None -) -> Dict[str, List[Tuple[float, float, str]]]: - """ - Function to parse phone CTMs in a given iteration - - Parameters - ---------- - aligner: Trainer - Aligner - iteration: int - Iteration to compute over - 
Returns - ------- - Dict - Per utterance CtmIntervals - """ - if iteration is None: - iteration = aligner.iteration - data = {} - for j in aligner.corpus.jobs: - phone_ctm_path = os.path.join(aligner.working_directory, f"phone.{iteration}.{j.name}.ctm") - with open(phone_ctm_path, "r", encoding="utf8") as f: - for line in f: - line = line.strip() - if line == "": - continue - line = line.split(" ") - utt = line[0] - begin = float(line[1]) - end = float(line[2]) - label = line[3] - if utt not in data: - data[utt] = [] - data[utt].append((begin, end, label)) - return data - - -def compare_alignments( - alignments_one: Dict[str, List[Tuple[float, float, str]]], - alignments_two: Dict[str, List[Tuple[float, float, str]]], - frame_shift: int, -) -> Tuple[int, Optional[float]]: - """ - Compares two sets of alignments for difference - - Parameters - ---------- - alignments_one: Dict - First set of alignments - alignments_two: Dict - Second set of alignments - frame_shift: int - Frame shift in feature generation, in ms - - Returns - ------- - int - Difference in number of aligned files - float - Mean boundary difference between the two alignments - """ - utterances_aligned_diff = len(alignments_two) - len(alignments_one) - utts_one = set(alignments_one.keys()) - utts_two = set(alignments_two.keys()) - common_utts = utts_one.intersection(utts_two) - differences = [] - for u in common_utts: - end = alignments_one[u][-1][1] - t = 0 - one_alignment = alignments_one[u] - two_alignment = alignments_two[u] - difference = 0 - while t < end: - one_label = None - two_label = None - for b, e, l in one_alignment: - if t < b: - continue - if t >= e: - break - one_label = l - for b, e, l in two_alignment: - if t < b: - continue - if t >= e: - break - two_label = l - if one_label != two_label: - difference += frame_shift - t += frame_shift - difference /= end - differences.append(difference) - if differences: - mean_difference = statistics.mean(differences) - else: - mean_difference = 
None - return utterances_aligned_diff, mean_difference - - -def compute_alignment_improvement(aligner: Union[BaseAligner, BaseTrainer]) -> None: - """ - Computes aligner improvements in terms of number of aligned files and phone boundaries - for debugging purposes - - Parameters - ---------- - aligner: :class:`~montreal_forced_aligner.trainers.BaseTrainer` or :class:`~montreal_forced_aligner.aligner.BaseAligner` - Aligner - """ - jobs = [x.alignment_improvement_arguments(aligner) for x in aligner.corpus.jobs] - if aligner.use_mp: - run_mp(compute_alignment_improvement_func, jobs, aligner.working_log_directory) - else: - run_non_mp(compute_alignment_improvement_func, jobs, aligner.working_log_directory) - - alignment_diff_path = os.path.join(aligner.working_directory, "train_change.csv") - if aligner.iteration == 0 or aligner.iteration not in aligner.realignment_iterations: - return - ind = aligner.realignment_iterations.index(aligner.iteration) - if ind != 0: - previous_iteration = aligner.realignment_iterations[ind - 1] - else: - previous_iteration = 0 - try: - previous_alignments = parse_iteration_alignments(aligner, previous_iteration) - except FileNotFoundError: - return - current_alignments = parse_iteration_alignments(aligner) - utterance_aligned_diff, mean_difference = compare_alignments( - previous_alignments, current_alignments, aligner.feature_config.frame_shift - ) - if not os.path.exists(alignment_diff_path): - with open(alignment_diff_path, "w", encoding="utf8") as f: - f.write( - "iteration,number_aligned,number_previously_aligned," - "difference_in_utts_aligned,mean_boundary_change\n" - ) - if aligner.iteration in aligner.realignment_iterations: - with open(alignment_diff_path, "a", encoding="utf8") as f: - f.write( - f"{aligner.iteration},{len(current_alignments)},{len(previous_alignments)}," - f"{utterance_aligned_diff},{mean_difference}\n" - ) - if not aligner.debug: - for j in jobs: - for p in j.phone_ctm_paths: - os.remove(p) - - -def 
ali_to_ctm_func( - log_path: str, - dictionaries: List[str], - ali_paths: Dict[str, str], - text_int_paths: Dict[str, str], - word_boundary_int_paths: Dict[str, str], - frame_shift: float, - model_path: str, - ctm_paths: Dict[str, str], - word_mode: bool, -) -> None: - """ - Multiprocessing function to convert alignment archives into CTM files - - Parameters - ---------- - log_path: str - Path to save log output - dictionaries: List[str] - List of dictionary names - ali_paths: Dict[str, str] - PronunciationDictionary of alignment archives per dictionary name - text_int_paths: Dict[str, str] - PronunciationDictionary of text int files per dictionary name - word_boundary_int_paths: Dict[str, str] - PronunciationDictionary of word boundary int files per dictionary name - frame_shift: float - Frame shift of feature generation in seconds - model_path: str - Path to the acoustic model file - ctm_paths: Dict[str, str] - PronunciationDictionary of CTM files per dictionary name - word_mode: bool - Flag for whether to parse words or phones - """ - with open(log_path, "w", encoding="utf8") as log_file: - for dict_name in dictionaries: - ali_path = ali_paths[dict_name] - text_int_path = text_int_paths[dict_name] - ctm_path = ctm_paths[dict_name] - word_boundary_int_path = word_boundary_int_paths[dict_name] - if os.path.exists(ctm_path): - return - lin_proc = subprocess.Popen( - [ - thirdparty_binary("linear-to-nbest"), - "ark:" + ali_path, - "ark:" + text_int_path, - "", - "", - "ark:-", - ], - stdout=subprocess.PIPE, - stderr=log_file, - env=os.environ, - ) - align_words_proc = subprocess.Popen( - [ - thirdparty_binary("lattice-align-words"), - word_boundary_int_path, - model_path, - "ark:-", - "ark:-", - ], - stdin=lin_proc.stdout, - stdout=subprocess.PIPE, - stderr=log_file, - env=os.environ, - ) - if word_mode: - nbest_proc = subprocess.Popen( - [ - thirdparty_binary("nbest-to-ctm"), - f"--frame-shift={frame_shift}", - "ark:-", - ctm_path, - ], - stderr=log_file, - 
stdin=align_words_proc.stdout, - env=os.environ, - ) - else: - phone_proc = subprocess.Popen( - [thirdparty_binary("lattice-to-phone-lattice"), model_path, "ark:-", "ark:-"], - stdout=subprocess.PIPE, - stdin=align_words_proc.stdout, - stderr=log_file, - env=os.environ, - ) - nbest_proc = subprocess.Popen( - [ - thirdparty_binary("nbest-to-ctm"), - f"--frame-shift={frame_shift}", - "ark:-", - ctm_path, - ], - stdin=phone_proc.stdout, - stderr=log_file, - env=os.environ, - ) - nbest_proc.communicate() - - -class NoCleanupWordCtmProcessWorker(mp.Process): - """ - Multiprocessing worker for loading word CTM files without any clean up - - Parameters - ---------- - job_name: int - Job name - to_process_queue: :class:`~multiprocessing.Queue` - Return queue of jobs for later workers to process - stopped: :class:`~montreal_forced_aligner.multiprocessing.helper.Stopped` - Stop check for processing - error_catching: CtmErrorDict - PronunciationDictionary for storing errors encountered - arguments: :class:`~montreal_forced_aligner.multiprocessing.classes.NoCleanupWordCtmArguments` - Arguments to pass to the CTM processing function - """ - - def __init__( - self, - job_name: int, - to_process_queue: mp.Queue, - stopped: Stopped, - error_catching: CtmErrorDict, - arguments: NoCleanupWordCtmArguments, - ): - mp.Process.__init__(self) - self.job_name = job_name - self.dictionaries = arguments.dictionaries - self.ctm_paths = arguments.ctm_paths - self.to_process_queue = to_process_queue - self.stopped = stopped - self.error_catching = error_catching - - # Corpus information - self.utterances = arguments.utterances - - # PronunciationDictionary information - self.dictionary_data = arguments.dictionary_data - - def run(self) -> None: - """ - Run the word processing with no clean up - """ - current_file_data = {} - - def process_current(cur_utt: Utterance, current_labels: CtmType): - """Process current stack of intervals""" - actual_labels = parse_from_word_no_cleanup( - 
current_labels, self.dictionary_data[dict_name].reversed_words_mapping - ) - current_file_data[cur_utt.name] = actual_labels - - def process_current_file(cur_file: str): - """Process current file and add to return queue""" - self.to_process_queue.put(("word", cur_file, current_file_data)) - - cur_utt = None - cur_file = None - utt_begin = 0 - current_labels = [] - try: - for dict_name in self.dictionaries: - with open(self.ctm_paths[dict_name], "r") as word_file: - for line in word_file: - line = line.strip() - if not line: - continue - interval = process_ctm_line(line) - utt = interval.utterance - if cur_utt is None: - cur_utt = self.utterances[dict_name][utt] - utt_begin = cur_utt.begin - cur_file = cur_utt.file_name - - if utt != cur_utt: - process_current(cur_utt, current_labels) - cur_utt = self.utterances[dict_name][utt] - file_name = cur_utt.file_name - if file_name != cur_file: - process_current_file(cur_file) - current_file_data = {} - cur_file = file_name - current_labels = [] - if utt_begin: - interval.shift_times(utt_begin) - current_labels.append(interval) - if current_labels: - process_current(cur_utt, current_labels) - process_current_file(cur_file) - except Exception: - self.stopped.stop() - exc_type, exc_value, exc_traceback = sys.exc_info() - self.error_catching[("word", self.job_name)] = "\n".join( - traceback.format_exception(exc_type, exc_value, exc_traceback) - ) - - -class CleanupWordCtmProcessWorker(mp.Process): - """ - Multiprocessing worker for loading word CTM files with cleaning up MFA-internal modifications - - Parameters - ---------- - job_name: int - Job name - to_process_queue: :class:`~multiprocessing.Queue` - Return queue of jobs for later workers to process - stopped: :class:`~montreal_forced_aligner.multiprocessing.helper.Stopped` - Stop check for processing - error_catching: CtmErrorDict - PronunciationDictionary for storing errors encountered - arguments: 
:class:`~montreal_forced_aligner.multiprocessing.classes.CleanupWordCtmArguments` - Arguments to pass to the CTM processing function - """ - - def __init__( - self, - job_name: int, - to_process_queue: mp.Queue, - stopped: Stopped, - error_catching: CtmErrorDict, - arguments: CleanupWordCtmArguments, - ): - mp.Process.__init__(self) - self.job_name = job_name - self.dictionaries = arguments.dictionaries - self.ctm_paths = arguments.ctm_paths - self.to_process_queue = to_process_queue - self.stopped = stopped - self.error_catching = error_catching - - # Corpus information - self.utterances = arguments.utterances - - # PronunciationDictionary information - self.dictionary_data = arguments.dictionary_data - - def run(self) -> None: - """ - Run the word processing with clean up - """ - current_file_data = {} - - def process_current(cur_utt: Utterance, current_labels: CtmType) -> None: - """Process current stack of intervals""" - text = cur_utt.text.split() - actual_labels = parse_from_word(current_labels, text, self.dictionary_data[dict_name]) - - current_file_data[cur_utt.name] = actual_labels - - def process_current_file(cur_file: str) -> None: - """Process current file and add to return queue""" - self.to_process_queue.put(("word", cur_file, current_file_data)) - - cur_utt = None - cur_file = None - utt_begin = 0 - current_labels = [] - try: - for dict_name in self.dictionaries: - ctm_path = self.ctm_paths[dict_name] - with open(ctm_path, "r") as word_file: - for line in word_file: - line = line.strip() - if not line: - continue - interval = process_ctm_line(line) - utt = interval.utterance - if cur_utt is None: - cur_utt = self.utterances[dict_name][utt] - utt_begin = cur_utt.begin - cur_file = cur_utt.file_name - - if utt != cur_utt: - process_current(cur_utt, current_labels) - cur_utt = self.utterances[dict_name][utt] - utt_begin = cur_utt.begin - file_name = cur_utt.file_name - if file_name != cur_file: - process_current_file(cur_file) - current_file_data = {} - 
cur_file = file_name - current_labels = [] - if utt_begin: - interval.shift_times(utt_begin) - current_labels.append(interval) - if current_labels: - process_current(cur_utt, current_labels) - process_current_file(cur_file) - except Exception: - self.stopped.stop() - exc_type, exc_value, exc_traceback = sys.exc_info() - self.error_catching[("word", self.job_name)] = "\n".join( - traceback.format_exception(exc_type, exc_value, exc_traceback) - ) - - -class PhoneCtmProcessWorker(mp.Process): - """ - Multiprocessing worker for loading phone CTM files - - Parameters - ---------- - job_name: int - Job name - to_process_queue: :class:`~multiprocessing.Queue` - Return queue of jobs for later workers to process - stopped: :class:`~montreal_forced_aligner.multiprocessing.helper.Stopped` - Stop check for processing - error_catching: CtmErrorDict - PronunciationDictionary for storing errors encountered - arguments: :class:`~montreal_forced_aligner.multiprocessing.classes.PhoneCtmArguments` - Arguments to pass to the CTM processing function - """ - - def __init__( - self, - job_name: int, - to_process_queue: mp.Queue, - stopped: Stopped, - error_catching: CtmErrorDict, - arguments: PhoneCtmArguments, - ): - mp.Process.__init__(self) - self.job_name = job_name - self.dictionaries = arguments.dictionaries - self.ctm_paths = arguments.ctm_paths - self.to_process_queue = to_process_queue - self.stopped = stopped - self.error_catching = error_catching - - self.utterances = arguments.utterances - - self.reversed_phone_mappings = arguments.reversed_phone_mappings - self.positions = arguments.positions - - def run(self) -> None: - """Run the phone processing""" - cur_utt = None - cur_file = None - utt_begin = 0 - current_labels = [] - - current_file_data = {} - - def process_current_utt(cur_utt: Utterance, current_labels: CtmType) -> None: - """Process current stack of intervals""" - actual_labels = parse_from_phone( - current_labels, self.reversed_phone_mappings[dict_name], 
self.positions[dict_name] - ) - current_file_data[cur_utt.name] = actual_labels - - def process_current_file(cur_file: str) -> None: - """Process current file and add to return queue""" - self.to_process_queue.put(("phone", cur_file, current_file_data)) - - try: - for dict_name in self.dictionaries: - with open(self.ctm_paths[dict_name], "r") as word_file: - for line in word_file: - line = line.strip() - if not line: - continue - interval = process_ctm_line(line) - utt = interval.utterance - if cur_utt is None: - cur_utt = self.utterances[dict_name][utt] - cur_file = cur_utt.file_name - utt_begin = cur_utt.begin - - if utt != cur_utt: - - process_current_utt(cur_utt, current_labels) - - cur_utt = self.utterances[dict_name][utt] - file_name = cur_utt.file_name - utt_begin = cur_utt.begin - - if file_name != cur_file: - process_current_file(cur_file) - current_file_data = {} - cur_file = file_name - current_labels = [] - if utt_begin: - interval.shift_times(utt_begin) - current_labels.append(interval) - if current_labels: - process_current_utt(cur_utt, current_labels) - process_current_file(cur_file) - except Exception: - self.stopped.stop() - exc_type, exc_value, exc_traceback = sys.exc_info() - self.error_catching[("phone", self.job_name)] = "\n".join( - traceback.format_exception(exc_type, exc_value, exc_traceback) - ) - - -class CombineProcessWorker(mp.Process): - """ - Multiprocessing worker for loading phone CTM files - - Parameters - ---------- - job_name: int - Job name - to_process_queue: :class:`~multiprocessing.Queue` - Input queue of phone and word ctms to combine - to_export_queue: :class:`~multiprocessing.Queue` - Export queue of combined CTMs - stopped: :class:`~montreal_forced_aligner.multiprocessing.helper.Stopped` - Stop check for processing - finished_combining: :class:`~montreal_forced_aligner.multiprocessing.helper.Stopped` - Signal that this worker has finished combining all CTMs - error_catching: CtmErrorDict - PronunciationDictionary for 
storing errors encountered - arguments: :class:`~montreal_forced_aligner.multiprocessing.classes.CombineCtmArguments` - Arguments to pass to the CTM combining function - """ - - def __init__( - self, - job_name: int, - to_process_queue: mp.Queue, - to_export_queue: mp.Queue, - stopped: Stopped, - finished_combining: Stopped, - error_catching: CtmErrorDict, - arguments: CombineCtmArguments, - ): - mp.Process.__init__(self) - self.job_name = job_name - self.to_process_queue = to_process_queue - self.to_export_queue = to_export_queue - self.stopped = stopped - self.finished_combining = finished_combining - self.error_catching = error_catching - - self.files = arguments.files - self.dictionary_data = arguments.dictionary_data - self.cleanup_textgrids = arguments.cleanup_textgrids - - def run(self) -> None: - """Run the combination function""" - sum_time = 0 - count_time = 0 - phone_data = {} - word_data = {} - while True: - try: - w_p, file_name, data = self.to_process_queue.get(timeout=queue_polling_timeout) - begin_time = time.time() - except Empty: - if self.finished_combining.stop_check(): - break - continue - self.to_process_queue.task_done() - if self.stopped.stop_check(): - continue - if w_p == "phone": - if file_name in word_data: - word_ctm = word_data.pop(file_name) - phone_ctm = data - else: - phone_data[file_name] = data - continue - else: - if file_name in phone_data: - phone_ctm = phone_data.pop(file_name) - word_ctm = data - else: - word_data[file_name] = data - continue - try: - file = self.files[file_name] - for u_name, u in file.utterances.items(): - if u_name not in word_ctm: - continue - u.word_labels = word_ctm[u_name] - u.phone_labels = phone_ctm[u_name] - data = generate_tiers(file, cleanup_textgrids=self.cleanup_textgrids) - self.to_export_queue.put((file_name, data)) - except Exception: - self.stopped.stop() - exc_type, exc_value, exc_traceback = sys.exc_info() - self.error_catching[("combining", self.job_name)] = "\n".join( - 
traceback.format_exception(exc_type, exc_value, exc_traceback) - ) - - sum_time += time.time() - begin_time - count_time += 1 - - -class ExportTextGridProcessWorker(mp.Process): - """ - Multiprocessing worker for exporting TextGrids - - Parameters - ---------- - for_write_queue: :class:`~multiprocessing.Queue` - Input queue of files to export - stopped: :class:`~montreal_forced_aligner.multiprocessing.helper.Stopped` - Stop check for processing - finished_processing: :class:`~montreal_forced_aligner.multiprocessing.helper.Stopped` - Input signal that all jobs have been added and no more new ones will come in - textgrid_errors: CtmErrorDict - PronunciationDictionary for storing errors encountered - arguments: :class:`~montreal_forced_aligner.multiprocessing.classes.ExportTextGridArguments` - Arguments to pass to the TextGrid export function - """ - - def __init__( - self, - for_write_queue: mp.Queue, - stopped: Stopped, - finished_processing: Stopped, - textgrid_errors: Dict[str, str], - arguments: ExportTextGridArguments, - ): - mp.Process.__init__(self) - self.for_write_queue = for_write_queue - self.stopped = stopped - self.finished_processing = finished_processing - self.textgrid_errors = textgrid_errors - - self.files = arguments.files - self.output_directory = arguments.output_directory - self.backup_output_directory = arguments.backup_output_directory - - self.frame_shift = arguments.frame_shift - - def run(self) -> None: - """Run the exporter function""" - while True: - try: - file_name, data = self.for_write_queue.get(timeout=queue_polling_timeout) - except Empty: - if self.finished_processing.stop_check(): - break - continue - self.for_write_queue.task_done() - if self.stopped.stop_check(): - continue - try: - overwrite = True - file = self.files[file_name] - output_path = file.construct_output_path( - self.output_directory, self.backup_output_directory - ) - - export_textgrid(file, output_path, data, self.frame_shift, overwrite) - except Exception: - 
exc_type, exc_value, exc_traceback = sys.exc_info() - self.textgrid_errors[file_name] = "\n".join( - traceback.format_exception(exc_type, exc_value, exc_traceback) - ) - - -class ExportPreparationProcessWorker(mp.Process): - """ - Multiprocessing worker for preparing CTMs for export - - Parameters - ---------- - to_export_queue: :class:`~multiprocessing.Queue` - Input queue of combined CTMs - for_write_queue: :class:`~multiprocessing.Queue` - Export queue of files to export - stopped: :class:`~montreal_forced_aligner.multiprocessing.helper.Stopped` - Stop check for processing - finished_combining: :class:`~montreal_forced_aligner.multiprocessing.helper.Stopped` - Input signal that all CTMs have been combined - files: Dict[str, File] - Files in corpus - """ - - def __init__( - self, - to_export_queue: mp.Queue, - for_write_queue: mp.Queue, - stopped: Stopped, - finished_combining: Stopped, - files: Dict[str, File], - ): - mp.Process.__init__(self) - self.to_export_queue = to_export_queue - self.for_write_queue = for_write_queue - self.stopped = stopped - self.finished_combining = finished_combining - - self.files = files - - def run(self) -> None: - """Run the export preparation worker""" - export_data = {} - try: - while True: - try: - file_name, data = self.to_export_queue.get(timeout=queue_polling_timeout) - except Empty: - if self.finished_combining.stop_check(): - break - continue - self.to_export_queue.task_done() - if self.stopped.stop_check(): - continue - file = self.files[file_name] - if len(file.speaker_ordering) > 1: - if file_name not in export_data: - export_data[file_name] = data - else: - export_data[file_name].update(data) - if len(export_data[file_name]) == len(file.speaker_ordering): - data = export_data.pop(file_name) - self.for_write_queue.put((file_name, data)) - else: - self.for_write_queue.put((file_name, data)) - - for k, v in export_data.items(): - self.for_write_queue.put((k, v)) - except Exception: - self.stopped.stop() - raise - - -def 
ctms_to_textgrids_mp(aligner: Aligner): - """ - Multiprocessing function for exporting alignment CTM information as TextGrids - - Parameters - ---------- - aligner: Aligner - Aligner - """ - export_begin = time.time() - manager = mp.Manager() - textgrid_errors = manager.dict() - error_catching = manager.dict() - stopped = Stopped() - backup_output_directory = None - if not aligner.align_config.overwrite: - backup_output_directory = os.path.join(aligner.align_directory, "textgrids") - os.makedirs(backup_output_directory, exist_ok=True) - - aligner.logger.debug("Beginning to process ctm files...") - ctm_begin_time = time.time() - word_procs = [] - phone_procs = [] - combine_procs = [] - finished_signals = [Stopped() for _ in range(aligner.corpus.num_jobs)] - finished_processing = Stopped() - to_process_queue = [mp.JoinableQueue() for _ in range(aligner.corpus.num_jobs)] - to_export_queue = mp.JoinableQueue() - for_write_queue = mp.JoinableQueue() - finished_combining = Stopped() - for j in aligner.corpus.jobs: - if aligner.align_config.cleanup_textgrids: - word_p = CleanupWordCtmProcessWorker( - j.name, - to_process_queue[j.name], - stopped, - error_catching, - j.cleanup_word_ctm_arguments(aligner), - ) - else: - word_p = NoCleanupWordCtmProcessWorker( - j.name, - to_process_queue[j.name], - stopped, - error_catching, - j.no_cleanup_word_ctm_arguments(aligner), - ) - - word_procs.append(word_p) - word_p.start() - - phone_p = PhoneCtmProcessWorker( - j.name, - to_process_queue[j.name], - stopped, - error_catching, - j.phone_ctm_arguments(aligner), - ) - phone_p.start() - phone_procs.append(phone_p) - - combine_p = CombineProcessWorker( - j.name, - to_process_queue[j.name], - to_export_queue, - stopped, - finished_signals[j.name], - error_catching, - j.combine_ctm_arguments(aligner), - ) - combine_p.start() - combine_procs.append(combine_p) - preparation_proc = ExportPreparationProcessWorker( - to_export_queue, for_write_queue, stopped, finished_combining, 
aligner.corpus.files - ) - preparation_proc.start() - - export_procs = [] - for j in aligner.corpus.jobs: - export_proc = ExportTextGridProcessWorker( - for_write_queue, - stopped, - finished_processing, - textgrid_errors, - j.export_textgrid_arguments(aligner), - ) - export_proc.start() - export_procs.append(export_proc) - - aligner.logger.debug("Waiting for processes to finish...") - for i in range(aligner.corpus.num_jobs): - word_procs[i].join() - phone_procs[i].join() - finished_signals[i].stop() - - aligner.logger.debug(f"Ctm parsers took {time.time() - ctm_begin_time} seconds") - - aligner.logger.debug("Waiting for processes to finish...") - for i in range(aligner.corpus.num_jobs): - to_process_queue[i].join() - combine_procs[i].join() - finished_combining.stop() - - to_export_queue.join() - preparation_proc.join() - - aligner.logger.debug(f"Combiners took {time.time() - ctm_begin_time} seconds") - aligner.logger.debug("Beginning export...") - - aligner.logger.debug(f"Adding jobs for export took {time.time() - export_begin}") - aligner.logger.debug("Waiting for export processes to join...") - - for_write_queue.join() - finished_processing.stop() - for i in range(aligner.corpus.num_jobs): - export_procs[i].join() - for_write_queue.join() - aligner.logger.debug(f"Export took {time.time() - export_begin} seconds") - - if error_catching: - aligner.logger.error("Error was encountered in processing CTMs") - for key, error in error_catching.items(): - aligner.logger.error(f"{key}:\n\n{error}") - raise AlignmentExportError(error_catching) - - if textgrid_errors: - aligner.logger.warning( - f"There were {len(textgrid_errors)} errors encountered in generating TextGrids. 
" - f"Check the output_errors.txt file in {os.path.join(aligner.textgrid_output)} " - f"for more details" - ) - output_textgrid_writing_errors(aligner.textgrid_output, textgrid_errors) - - -def convert_ali_to_textgrids(aligner: Aligner) -> None: - """ - Multiprocessing function that aligns based on the current model - - See: - - - http://kaldi-asr.org/doc/linear-to-nbest_8cc.html - - http://kaldi-asr.org/doc/lattice-align-words_8cc.html - - http://kaldi-asr.org/doc/lattice-to-phone-lattice_8cc.html - - http://kaldi-asr.org/doc/nbest-to-ctm_8cc.html - - for more details - on the Kaldi binaries this function calls. - - Also see https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/get_train_ctm.sh - for the bash script that this function was based on. - - Parameters - ---------- - aligner: :class:`~montreal_forced_aligner.abc.Aligner` - Aligner - """ - log_directory = aligner.working_log_directory - os.makedirs(aligner.textgrid_output, exist_ok=True) - jobs = [x.ali_to_word_ctm_arguments(aligner) for x in aligner.corpus.jobs] # Word CTM jobs - jobs += [x.ali_to_phone_ctm_arguments(aligner) for x in aligner.corpus.jobs] # Phone CTM jobs - aligner.logger.info("Generating CTMs from alignment...") - if aligner.use_mp: - run_mp(ali_to_ctm_func, jobs, log_directory) - else: - run_non_mp(ali_to_ctm_func, jobs, log_directory) - aligner.logger.info("Finished generating CTMs!") - - aligner.logger.info("Exporting TextGrids from CTMs...") - if aligner.use_mp: - ctms_to_textgrids_mp(aligner) - else: - ctms_to_textgrids_non_mp(aligner) - aligner.logger.info("Finished exporting TextGrids!") - - -def tree_stats_func( - log_path: str, - dictionaries: List[str], - ci_phones: str, - model_path: str, - feature_strings: Dict[str, str], - ali_paths: Dict[str, str], - treeacc_paths: Dict[str, str], -) -> None: - """ - Multiprocessing function for calculating tree stats for training - - Parameters - ---------- - log_path: str - Path to save log output - dictionaries: List[str] - 
List of dictionary names - ci_phones: str - Colon-separated list of context-independent phones - model_path: str - Path to the acoustic model file - feature_strings: Dict[str, str] - PronunciationDictionary of feature strings per dictionary name - ali_paths: Dict[str, str] - PronunciationDictionary of alignment archives per dictionary name - treeacc_paths: Dict[str, str] - PronunciationDictionary of accumulated tree stats files per dictionary name - """ - with open(log_path, "w", encoding="utf8") as log_file: - for dict_name in dictionaries: - feature_string = feature_strings[dict_name] - ali_path = ali_paths[dict_name] - treeacc_path = treeacc_paths[dict_name] - subprocess.call( - [ - thirdparty_binary("acc-tree-stats"), - f"--ci-phones={ci_phones}", - model_path, - feature_string, - f"ark:{ali_path}", - treeacc_path, - ], - stderr=log_file, - ) - - -def tree_stats(trainer: Trainer) -> None: - """ - Multiprocessing function that computes stats for decision tree training - - See http://kaldi-asr.org/doc/acc-tree-stats_8cc.html for more details - on the Kaldi binary this runs. 
- - Parameters - ---------- - trainer: :class:`~montreal_forced_aligner.abc.Trainer` - Trainer - """ - - jobs = [j.tree_stats_arguments(trainer) for j in trainer.corpus.jobs] - - if trainer.use_mp: - run_mp(tree_stats_func, jobs, trainer.working_log_directory) - else: - run_non_mp(tree_stats_func, jobs, trainer.working_log_directory) - - tree_accs = [] - for x in jobs: - tree_accs.extend(x.treeacc_paths.values()) - log_path = os.path.join(trainer.working_log_directory, "sum_tree_acc.log") - with open(log_path, "w", encoding="utf8") as log_file: - subprocess.call( - [ - thirdparty_binary("sum-tree-stats"), - os.path.join(trainer.working_directory, "treeacc"), - ] - + tree_accs, - stderr=log_file, - ) - if not trainer.debug: - for f in tree_accs: - os.remove(f) - - -def convert_alignments_func( - log_path: str, - dictionaries: List[str], - model_path: str, - tree_path: str, - align_model_path: str, - ali_paths: Dict[str, str], - new_ali_paths: Dict[str, str], -) -> None: - """ - Multiprocessing function for converting alignments from a previous trainer - - Parameters - ---------- - log_path: str - Path to save log output - dictionaries: List[str] - List of dictionary names - model_path: str - Path to the acoustic model file - tree_path: str - Path to the acoustic model tree file - align_model_path: str - Path to the alignment acoustic model file - ali_paths: Dict[str, str] - PronunciationDictionary of alignment archives per dictionary name - new_ali_paths: Dict[str, str] - PronunciationDictionary of new alignment archives per dictionary name - """ - with open(log_path, "w", encoding="utf8") as log_file: - for dict_name in dictionaries: - ali_path = ali_paths[dict_name] - new_ali_path = new_ali_paths[dict_name] - subprocess.call( - [ - thirdparty_binary("convert-ali"), - align_model_path, - model_path, - tree_path, - f"ark:{ali_path}", - f"ark:{new_ali_path}", - ], - stderr=log_file, - ) - - -def convert_alignments(trainer: Trainer) -> None: - """ - Multiprocessing 
function that converts alignments from previous training - - See http://kaldi-asr.org/doc/convert-ali_8cc.html for more details - on the Kaldi binary this runs. - - Parameters - ---------- - trainer: :class:`~montreal_forced_aligner.abc.Trainer` - Trainer - """ - - jobs = [x.convert_alignment_arguments(trainer) for x in trainer.corpus.jobs] - if trainer.use_mp: - run_mp(convert_alignments_func, jobs, trainer.working_log_directory) - else: - run_non_mp(convert_alignments_func, jobs, trainer.working_log_directory) - - -def calc_fmllr_func( - log_path: str, - dictionaries: List[str], - feature_strings: Dict[str, str], - ali_paths: Dict[str, str], - ali_model_path: str, - model_path: str, - spk2utt_paths: Dict[str, str], - trans_paths: Dict[str, str], - fmllr_options: MetaDict, -) -> None: - """ - Multiprocessing function for calculating fMLLR transforms - - Parameters - ---------- - log_path: str - Path to save log output - dictionaries: List[str] - List of dictionary names - feature_strings: Dict[str, str] - PronunciationDictionary of feature strings per dictionary name - ali_paths: Dict[str, str] - PronunciationDictionary of alignment archives per dictionary name - ali_model_path: str - Path to the alignment acoustic model file - model_path: str - Path to the acoustic model file - spk2utt_paths: Dict[str, str] - PronunciationDictionary of spk2utt scps per dictionary name - trans_paths: Dict[str, str] - PronunciationDictionary of fMLLR transform archives per dictionary name - fmllr_options: :class:`~montreal_forced_aligner.abc.MetaDict` - Options for fMLLR estimation - """ - with open(log_path, "w", encoding="utf8") as log_file: - log_file.writelines(f"{k}: {v}\n" for k, v in os.environ.items()) - for dict_name in dictionaries: - feature_string = feature_strings[dict_name] - ali_path = ali_paths[dict_name] - spk2utt_path = spk2utt_paths[dict_name] - trans_path = trans_paths[dict_name] - post_proc = subprocess.Popen( - [thirdparty_binary("ali-to-post"), 
f"ark:{ali_path}", "ark:-"], - stderr=log_file, - stdout=subprocess.PIPE, - env=os.environ, - ) - - weight_proc = subprocess.Popen( - [ - thirdparty_binary("weight-silence-post"), - "0.0", - fmllr_options["silence_csl"], - ali_model_path, - "ark:-", - "ark:-", - ], - stderr=log_file, - stdin=post_proc.stdout, - stdout=subprocess.PIPE, - env=os.environ, - ) - - if ali_model_path != model_path: - post_gpost_proc = subprocess.Popen( - [ - thirdparty_binary("gmm-post-to-gpost"), - ali_model_path, - feature_string, - "ark:-", - "ark:-", - ], - stderr=log_file, - stdin=weight_proc.stdout, - stdout=subprocess.PIPE, - env=os.environ, - ) - est_proc = subprocess.Popen( - [ - thirdparty_binary("gmm-est-fmllr-gpost"), - "--verbose=4", - f"--fmllr-update-type={fmllr_options['fmllr_update_type']}", - f"--spk2utt=ark:{spk2utt_path}", - model_path, - feature_string, - "ark,s,cs:-", - f"ark:{trans_path}", - ], - stderr=log_file, - stdin=post_gpost_proc.stdout, - env=os.environ, - ) - est_proc.communicate() - - else: - - if os.path.exists(trans_path): - cmp_trans_path = trans_paths[dict_name] + ".tmp" - est_proc = subprocess.Popen( - [ - thirdparty_binary("gmm-est-fmllr"), - "--verbose=4", - f"--fmllr-update-type={fmllr_options['fmllr_update_type']}", - f"--spk2utt=ark:{spk2utt_path}", - model_path, - feature_string, - "ark:-", - "ark:-", - ], - stderr=log_file, - stdin=weight_proc.stdout, - stdout=subprocess.PIPE, - env=os.environ, - ) - comp_proc = subprocess.Popen( - [ - thirdparty_binary("compose-transforms"), - "--b-is-affine=true", - "ark:-", - f"ark:{trans_path}", - f"ark:{cmp_trans_path}", - ], - stderr=log_file, - stdin=est_proc.stdout, - env=os.environ, - ) - comp_proc.communicate() - - os.remove(trans_path) - os.rename(cmp_trans_path, trans_path) - else: - est_proc = subprocess.Popen( - [ - thirdparty_binary("gmm-est-fmllr"), - "--verbose=4", - f"--fmllr-update-type={fmllr_options['fmllr_update_type']}", - f"--spk2utt=ark:{spk2utt_path}", - model_path, - feature_string, 
- "ark,s,cs:-", - f"ark:{trans_path}", - ], - stderr=log_file, - stdin=weight_proc.stdout, - env=os.environ, - ) - est_proc.communicate() - - -def calc_fmllr(aligner: Aligner) -> None: - """ - Multiprocessing function that computes speaker adaptation (fMLLR) - - See: - - - http://kaldi-asr.org/doc/gmm-est-fmllr_8cc.html - - http://kaldi-asr.org/doc/ali-to-post_8cc.html - - http://kaldi-asr.org/doc/weight-silence-post_8cc.html - - http://kaldi-asr.org/doc/compose-transforms_8cc.html - - http://kaldi-asr.org/doc/transform-feats_8cc.html - - for more details - on the Kaldi binary this runs. - - Also see https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/align_fmllr.sh - for the original bash script that this function was based on. - - Parameters - ---------- - aligner: :class:`~montreal_forced_aligner.abc.Aligner` - Aligner - """ - begin = time.time() - log_directory = aligner.working_log_directory - - jobs = [x.calc_fmllr_arguments(aligner) for x in aligner.corpus.jobs] - if aligner.use_mp: - run_mp(calc_fmllr_func, jobs, log_directory) - else: - run_non_mp(calc_fmllr_func, jobs, log_directory) - aligner.speaker_independent = False - aligner.logger.debug(f"Fmllr calculation took {time.time() - begin}") - - -def acc_stats_two_feats_func( - log_path: str, - dictionaries: List[str], - ali_paths: Dict[str, str], - acc_paths: Dict[str, str], - model_path: str, - feature_strings: Dict[str, str], - si_feature_strings: Dict[str, str], -) -> None: - """ - Multiprocessing function for accumulating stats across speaker-independent and - speaker-adapted features - - Parameters - ---------- - log_path: str - Path to save log output - dictionaries: List[str] - List of dictionary names - ali_paths: Dict[str, str] - PronunciationDictionary of alignment archives per dictionary name - acc_paths: Dict[str, str] - PronunciationDictionary of accumulated stats files per dictionary name - model_path: str - Path to the acoustic model file - feature_strings: Dict[str, str] - 
PronunciationDictionary of feature strings per dictionary name - si_feature_strings: Dict[str, str] - PronunciationDictionary of speaker-independent feature strings per dictionary name - """ - with open(log_path, "w", encoding="utf8") as log_file: - for dict_name in dictionaries: - ali_path = ali_paths[dict_name] - acc_path = acc_paths[dict_name] - feature_string = feature_strings[dict_name] - si_feature_string = si_feature_strings[dict_name] - ali_to_post_proc = subprocess.Popen( - [thirdparty_binary("ali-to-post"), f"ark:{ali_path}", "ark:-"], - stderr=log_file, - stdout=subprocess.PIPE, - env=os.environ, - ) - acc_proc = subprocess.Popen( - [ - thirdparty_binary("gmm-acc-stats-twofeats"), - model_path, - feature_string, - si_feature_string, - "ark,s,cs:-", - acc_path, - ], - stderr=log_file, - stdin=ali_to_post_proc.stdout, - env=os.environ, - ) - acc_proc.communicate() - - -def create_align_model(aligner: SatTrainer) -> None: - """ - Create alignment model for speaker-adapted training that will use speaker-independent - features in later aligning - - Parameters - ---------- - aligner: :class:`~montreal_forced_aligner.trainers.SatTrainer` - Aligner - """ - aligner.logger.info("Creating alignment model for speaker-independent features...") - begin = time.time() - log_directory = aligner.working_log_directory - - model_path = os.path.join(aligner.working_directory, "final.mdl") - align_model_path = os.path.join(aligner.working_directory, "final.alimdl") - arguments = [x.acc_stats_two_feats_arguments(aligner) for x in aligner.corpus.jobs] - if aligner.use_mp: - run_mp(acc_stats_two_feats_func, arguments, log_directory) - else: - run_non_mp(acc_stats_two_feats_func, arguments, log_directory) - - log_path = os.path.join(aligner.working_log_directory, "align_model_est.log") - with open(log_path, "w", encoding="utf8") as log_file: - - acc_files = [] - for x in arguments: - acc_files.extend(x.acc_paths.values()) - sum_proc = subprocess.Popen( - 
[thirdparty_binary("gmm-sum-accs"), "-"] + acc_files, - stderr=log_file, - stdout=subprocess.PIPE, - env=os.environ, - ) - est_proc = subprocess.Popen( - [ - thirdparty_binary("gmm-est"), - "--remove-low-count-gaussians=false", - f"--power={aligner.power}", - model_path, - "-", - align_model_path, - ], - stdin=sum_proc.stdout, - stderr=log_file, - env=os.environ, - ) - est_proc.communicate() - if not aligner.debug: - for f in acc_files: - os.remove(f) - - aligner.logger.debug(f"Alignment model creation took {time.time() - begin}") - - -def lda_acc_stats_func( - log_path: str, - dictionaries: List[str], - feature_strings: Dict[str, str], - ali_paths: Dict[str, str], - model_path: str, - lda_options: MetaDict, - acc_paths: Dict[str, str], -) -> None: - """ - Multiprocessing function to accumulate LDA stats - - Parameters - ---------- - log_path: str - Path to save log output - dictionaries: List[str] - List of dictionary names - feature_strings: Dict[str, str] - PronunciationDictionary of feature strings per dictionary name - ali_paths: Dict[str, str] - Dictionary of alignment archives per dictionary name - model_path: str - Path to the acoustic model file - lda_options: :class:`~montreal_forced_aligner.abc.MetaDict` - Options for LDA - acc_paths: Dict[str, str] - Dictionary of accumulated stats files per dictionary name - """ - with open(log_path, "w", encoding="utf8") as log_file: - for dict_name in dictionaries: - ali_path = ali_paths[dict_name] - feature_string = feature_strings[dict_name] - acc_path = acc_paths[dict_name] - ali_to_post_proc = subprocess.Popen( - [thirdparty_binary("ali-to-post"), f"ark:{ali_path}", "ark:-"], - stderr=log_file, - stdout=subprocess.PIPE, - env=os.environ, - ) - weight_silence_post_proc = subprocess.Popen( - [ - thirdparty_binary("weight-silence-post"), - f"{lda_options['boost_silence']}", - lda_options["silence_csl"], - model_path, - "ark:-", - "ark:-", - ], - stdin=ali_to_post_proc.stdout, - stderr=log_file, - 
stdout=subprocess.PIPE, - env=os.environ, - ) - acc_lda_post_proc = subprocess.Popen( - [ - thirdparty_binary("acc-lda"), - f"--rand-prune={lda_options['random_prune']}", - model_path, - feature_string, - "ark,s,cs:-", - acc_path, - ], - stdin=weight_silence_post_proc.stdout, - stderr=log_file, - env=os.environ, - ) - acc_lda_post_proc.communicate() - - -def lda_acc_stats(aligner: LdaTrainer) -> None: - """ - Multiprocessing function that accumulates LDA statistics - - See: - - - http://kaldi-asr.org/doc/ali-to-post_8cc.html - - http://kaldi-asr.org/doc/weight-silence-post_8cc.html - - http://kaldi-asr.org/doc/acc-lda_8cc.html - - http://kaldi-asr.org/doc/est-lda_8cc.html - - for more details - on the Kaldi binary this runs. - - Also see https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/train_lda_mllt.sh - for the original bash script that this function was based on. - - Parameters - ---------- - aligner: :class:`~montreal_forced_aligner.trainers.LdaTrainer` - Trainer - """ - arguments = [x.lda_acc_stats_arguments(aligner) for x in aligner.corpus.jobs] - - if aligner.use_mp: - run_mp(lda_acc_stats_func, arguments, aligner.working_log_directory) - else: - run_non_mp(lda_acc_stats_func, arguments, aligner.working_log_directory) - - log_path = os.path.join(aligner.working_log_directory, "lda_est.log") - acc_list = [] - for x in arguments: - acc_list.extend(x.acc_paths.values()) - with open(log_path, "w", encoding="utf8") as log_file: - est_lda_proc = subprocess.Popen( - [ - thirdparty_binary("est-lda"), - f"--write-full-matrix={os.path.join(aligner.working_directory, 'full.mat')}", - f"--dim={aligner.lda_dimension}", - os.path.join(aligner.working_directory, "lda.mat"), - ] - + acc_list, - stderr=log_file, - env=os.environ, - ) - est_lda_proc.communicate() - - -def calc_lda_mllt_func( - log_path: str, - dictionaries: List[str], - feature_strings: Dict[str, str], - ali_paths: Dict[str, str], - model_path: str, - lda_options: MetaDict, - macc_paths: 
Dict[str, str], -) -> None: - """ - Multiprocessing function for estimating LDA with MLLT - - Parameters - ---------- - log_path: str - Path to save log output - dictionaries: List[str] - List of dictionary names - feature_strings: Dict[str, str] - Dictionary of feature strings per dictionary name - ali_paths: Dict[str, str] - Dictionary of alignment archives per dictionary name - model_path: str - Path to the acoustic model file - lda_options: :class:`~montreal_forced_aligner.abc.MetaDict` - Options for LDA - macc_paths: Dict[str, str] - Dictionary of accumulated stats files per dictionary name - """ - # Estimating MLLT - with open(log_path, "w", encoding="utf8") as log_file: - for dict_name in dictionaries: - ali_path = ali_paths[dict_name] - feature_string = feature_strings[dict_name] - macc_path = macc_paths[dict_name] - post_proc = subprocess.Popen( - [thirdparty_binary("ali-to-post"), f"ark:{ali_path}", "ark:-"], - stdout=subprocess.PIPE, - stderr=log_file, - env=os.environ, - ) - - weight_proc = subprocess.Popen( - [ - thirdparty_binary("weight-silence-post"), - "0.0", - lda_options["silence_csl"], - model_path, - "ark:-", - "ark:-", - ], - stdin=post_proc.stdout, - stdout=subprocess.PIPE, - stderr=log_file, - env=os.environ, - ) - acc_proc = subprocess.Popen( - [ - thirdparty_binary("gmm-acc-mllt"), - f"--rand-prune={lda_options['random_prune']}", - model_path, - feature_string, - "ark,s,cs:-", - macc_path, - ], - stdin=weight_proc.stdout, - stderr=log_file, - env=os.environ, - ) - acc_proc.communicate() - - -def calc_lda_mllt(aligner: LdaTrainer) -> None: - """ - Multiprocessing function that calculates LDA+MLLT transformations - - See: - - - http://kaldi-asr.org/doc/ali-to-post_8cc.html - - http://kaldi-asr.org/doc/weight-silence-post_8cc.html - - http://kaldi-asr.org/doc/gmm-acc-mllt_8cc.html - - http://kaldi-asr.org/doc/est-mllt_8cc.html - - http://kaldi-asr.org/doc/gmm-transform-means_8cc.html - - http://kaldi-asr.org/doc/compose-transforms_8cc.html - 
- for more details - on the Kaldi binary this runs. - - Also see https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/train_lda_mllt.sh - for the original bash script that this function was based on. - - Parameters - ---------- - aligner: :class:`~montreal_forced_aligner.trainers.LdaTrainer` - Trainer - """ - jobs = [x.calc_lda_mllt_arguments(aligner) for x in aligner.corpus.jobs] - - if aligner.use_mp: - run_mp(calc_lda_mllt_func, jobs, aligner.working_log_directory) - else: - run_non_mp(calc_lda_mllt_func, jobs, aligner.working_log_directory) - - log_path = os.path.join( - aligner.working_log_directory, f"transform_means.{aligner.iteration}.log" - ) - previous_mat_path = os.path.join(aligner.working_directory, "lda.mat") - new_mat_path = os.path.join(aligner.working_directory, "lda_new.mat") - composed_path = os.path.join(aligner.working_directory, "lda_composed.mat") - with open(log_path, "a", encoding="utf8") as log_file: - macc_list = [] - for x in jobs: - macc_list.extend(x.macc_paths.values()) - subprocess.call( - [thirdparty_binary("est-mllt"), new_mat_path] + macc_list, - stderr=log_file, - env=os.environ, - ) - subprocess.call( - [ - thirdparty_binary("gmm-transform-means"), - new_mat_path, - aligner.current_model_path, - aligner.current_model_path, - ], - stderr=log_file, - env=os.environ, - ) - - if os.path.exists(previous_mat_path): - subprocess.call( - [ - thirdparty_binary("compose-transforms"), - new_mat_path, - previous_mat_path, - composed_path, - ], - stderr=log_file, - env=os.environ, - ) - os.remove(previous_mat_path) - os.rename(composed_path, previous_mat_path) - else: - os.rename(new_mat_path, previous_mat_path) - - -def map_acc_stats_func( - log_path: str, - dictionaries: List[str], - feature_strings: Dict[str, str], - model_path: str, - ali_paths: Dict[str, str], - acc_paths: Dict[str, str], -) -> None: - """ - Multiprocessing function for accumulating mapped stats for adapting acoustic models to new - domains - - Parameters - 
---------- - log_path: str - Path to save log output - dictionaries: List[str] - List of dictionary names - feature_strings: Dict[str, str] - PronunciationDictionary of feature strings per dictionary name - model_path: str - Path to the acoustic model file - ali_paths: Dict[str, str] - PronunciationDictionary of alignment archives per dictionary name - acc_paths: Dict[str, str] - PronunciationDictionary of accumulated stats files per dictionary name - """ - with open(log_path, "w", encoding="utf8") as log_file: - for dict_name in dictionaries: - feature_string = feature_strings[dict_name] - acc_path = acc_paths[dict_name] - ali_path = ali_paths[dict_name] - acc_proc = subprocess.Popen( - [ - thirdparty_binary("gmm-acc-stats-ali"), - model_path, - feature_string, - f"ark,s,cs:{ali_path}", - acc_path, - ], - stderr=log_file, - env=os.environ, - ) - acc_proc.communicate() - - -def train_map(aligner: AdaptingAligner) -> None: - """ - Trains an adapted acoustic model through mapping model states and update those with - enough data - - Source: https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/train_map.sh - - Parameters - ---------- - aligner: :class:`~montreal_forced_aligner.aligner.AdaptingAligner` - Adapting aligner - """ - begin = time.time() - initial_mdl_path = os.path.join(aligner.working_directory, "0.mdl") - final_mdl_path = os.path.join(aligner.working_directory, "final.mdl") - log_directory = aligner.working_log_directory - os.makedirs(log_directory, exist_ok=True) - - jobs = [x.map_acc_stats_arguments(aligner) for x in aligner.corpus.jobs] - if aligner.use_mp: - run_mp(map_acc_stats_func, jobs, log_directory) - else: - run_non_mp(map_acc_stats_func, jobs, log_directory) - log_path = os.path.join(aligner.working_log_directory, "map_model_est.log") - occs_path = os.path.join(aligner.working_directory, "final.occs") - with open(log_path, "w", encoding="utf8") as log_file: - acc_files = [] - for j in jobs: - acc_files.extend(j.acc_paths.values()) - 
sum_proc = subprocess.Popen( - [thirdparty_binary("gmm-sum-accs"), "-"] + acc_files, - stderr=log_file, - stdout=subprocess.PIPE, - env=os.environ, - ) - ismooth_proc = subprocess.Popen( - [ - thirdparty_binary("gmm-ismooth-stats"), - "--smooth-from-model", - f"--tau={aligner.mapping_tau}", - initial_mdl_path, - "-", - "-", - ], - stderr=log_file, - stdin=sum_proc.stdout, - stdout=subprocess.PIPE, - env=os.environ, - ) - est_proc = subprocess.Popen( - [ - thirdparty_binary("gmm-est"), - "--update-flags=m", - f"--write-occs={occs_path}", - "--remove-low-count-gaussians=false", - initial_mdl_path, - "-", - final_mdl_path, - ], - stdin=ismooth_proc.stdout, - stderr=log_file, - env=os.environ, - ) - est_proc.communicate() - - initial_alimdl_path = os.path.join(aligner.working_directory, "0.alimdl") - final_alimdl_path = os.path.join(aligner.working_directory, "0.alimdl") - if os.path.exists(initial_alimdl_path): - aligner.speaker_independent = True - jobs = [x.map_acc_stats_arguments(aligner) for x in aligner.corpus.jobs] - if aligner.use_mp: - run_mp(map_acc_stats_func, jobs, log_directory) - else: - run_non_mp(map_acc_stats_func, jobs, log_directory) - - log_path = os.path.join(aligner.working_log_directory, "map_model_est.log") - with open(log_path, "w", encoding="utf8") as log_file: - acc_files = [] - for j in jobs: - acc_files.extend(j.acc_paths) - sum_proc = subprocess.Popen( - [thirdparty_binary("gmm-sum-accs"), "-"] + acc_files, - stderr=log_file, - stdout=subprocess.PIPE, - env=os.environ, - ) - ismooth_proc = subprocess.Popen( - [ - thirdparty_binary("gmm-ismooth-stats"), - "--smooth-from-model", - f"--tau={aligner.mapping_tau}", - initial_alimdl_path, - "-", - "-", - ], - stderr=log_file, - stdin=sum_proc.stdout, - stdout=subprocess.PIPE, - env=os.environ, - ) - est_proc = subprocess.Popen( - [ - thirdparty_binary("gmm-est"), - "--update-flags=m", - "--remove-low-count-gaussians=false", - initial_alimdl_path, - "-", - final_alimdl_path, - ], - 
stdin=ismooth_proc.stdout, - stderr=log_file, - env=os.environ, - ) - est_proc.communicate() - - aligner.logger.debug(f"Mapping models took {time.time() - begin}") - - -def test_utterances_func( - log_path: str, - dictionaries: List[str], - feature_strings: Dict[str, str], - words_paths: Dict[str, str], - graphs_paths: Dict[str, str], - text_int_paths: Dict[str, str], - edits_paths: Dict[str, str], - out_int_paths: Dict[str, str], - model_path: str, -): - """ - Multiprocessing function to test utterance transcriptions - - Parameters - ---------- - log_path: str - Log path - dictionaries: List[str] - List of dictionaries - feature_strings: Dict[str, str] - PronunciationDictionary of feature strings per dictionary name - words_paths: Dict[str, str] - PronunciationDictionary of word mapping files per dictionary name - graphs_paths: Dict[str, str] - PronunciationDictionary of utterance FST graph archives per dictionary name - text_int_paths: Dict[str, str] - PronunciationDictionary of text.int files per dictionary name - edits_paths: Dict[str, str] - PronunciationDictionary of paths to save transcription differences per dictionary name - out_int_paths: Dict[str, str] - PronunciationDictionary of output .int files per dictionary name - model_path: str - Acoustic model path - """ - acoustic_scale = 0.1 - beam = 15.0 - lattice_beam = 8.0 - max_active = 750 - with open(log_path, "w") as log_file: - for dict_name in dictionaries: - words_path = words_paths[dict_name] - graphs_path = graphs_paths[dict_name] - feature_string = feature_strings[dict_name] - edits_path = edits_paths[dict_name] - text_int_path = text_int_paths[dict_name] - out_int_path = out_int_paths[dict_name] - latgen_proc = subprocess.Popen( - [ - thirdparty_binary("gmm-latgen-faster"), - f"--acoustic-scale={acoustic_scale}", - f"--beam={beam}", - f"--max-active={max_active}", - f"--lattice-beam={lattice_beam}", - f"--word-symbol-table={words_path}", - model_path, - "ark:" + graphs_path, - feature_string, - 
"ark:-", - ], - stderr=log_file, - stdout=subprocess.PIPE, - ) - - oracle_proc = subprocess.Popen( - [ - thirdparty_binary("lattice-oracle"), - "ark:-", - f"ark,t:{text_int_path}", - f"ark,t:{out_int_path}", - f"ark,t:{edits_path}", - ], - stderr=log_file, - stdin=latgen_proc.stdout, - ) - oracle_proc.communicate() - - -def compile_utterance_train_graphs_func( - log_path: str, - dictionaries: List[str], - disambig_int_paths: Dict[str, str], - disambig_L_fst_paths: Dict[str, str], - fst_paths: Dict[str, str], - graphs_paths: Dict[str, str], - model_path: str, - tree_path: str, -): - """ - Multiprocessing function to compile utterance FSTs - - Parameters - ---------- - log_path: str - Log path - dictionaries: List[str] - List of dictionaries - disambig_int_paths: Dict[str, str] - PronunciationDictionary of disambiguation symbol int files per dictionary name - disambig_L_fst_paths: Dict[str, str] - PronunciationDictionary of disambiguation lexicon FSTs per dictionary name - fst_paths: Dict[str, str] - PronunciationDictionary of pregenerated utterance FST scp files per dictionary name - graphs_paths: Dict[str, str] - PronunciationDictionary of utterance FST graph archives per dictionary name - model_path: str - Acoustic model path - tree_path: str - Acoustic model's tree path - """ - with open(log_path, "w") as log_file: - for dict_name in dictionaries: - disambig_int_path = disambig_int_paths[dict_name] - disambig_L_fst_path = disambig_L_fst_paths[dict_name] - fst_path = fst_paths[dict_name] - graphs_path = graphs_paths[dict_name] - proc = subprocess.Popen( - [ - thirdparty_binary("compile-train-graphs-fsts"), - "--transition-scale=1.0", - "--self-loop-scale=0.1", - f"--read-disambig-syms={disambig_int_path}", - tree_path, - model_path, - disambig_L_fst_path, - f"ark:{fst_path}", - f"ark:{graphs_path}", - ], - stderr=log_file, - ) - - proc.communicate() diff --git a/montreal_forced_aligner/multiprocessing/classes.py b/montreal_forced_aligner/multiprocessing/classes.py 
deleted file mode 100644 index edf89703..00000000 --- a/montreal_forced_aligner/multiprocessing/classes.py +++ /dev/null @@ -1,2466 +0,0 @@ -""" -Multiprocessing classes ------------------------ - -""" -from __future__ import annotations - -import os -from typing import TYPE_CHECKING, Collection, Dict, List, NamedTuple, Optional, Set, Tuple - -if TYPE_CHECKING: - from ..corpus.classes import File, Speaker, Utterance - -from ..abc import IvectorExtractor, MetaDict, MfaWorker -from ..helper import output_mapping, save_scp - -if TYPE_CHECKING: - from ..abc import Aligner, MappingType, ReversedMappingType, WordsType - from ..aligner.adapting import AdaptingAligner - from ..aligner.base import BaseAligner - from ..config import FeatureConfig - from ..corpus import Corpus - from ..dictionary import DictionaryData - from ..segmenter import Segmenter - from ..trainers import ( - BaseTrainer, - IvectorExtractorTrainer, - LdaTrainer, - MonophoneTrainer, - SatTrainer, - ) - from ..transcriber import Transcriber - from ..validator import CorpusValidator - - -__all__ = [ - "Job", - "AlignArguments", - "VadArguments", - "SegmentVadArguments", - "CreateHclgArguments", - "AccGlobalStatsArguments", - "AccStatsArguments", - "AccIvectorStatsArguments", - "AccStatsTwoFeatsArguments", - "AliToCtmArguments", - "MfccArguments", - "ScoreArguments", - "DecodeArguments", - "PhoneCtmArguments", - "CombineCtmArguments", - "CleanupWordCtmArguments", - "NoCleanupWordCtmArguments", - "LmRescoreArguments", - "AlignmentImprovementArguments", - "ConvertAlignmentsArguments", - "CalcFmllrArguments", - "CalcLdaMlltArguments", - "GmmGselectArguments", - "FinalFmllrArguments", - "LatGenFmllrArguments", - "FmllrRescoreArguments", - "TreeStatsArguments", - "LdaAccStatsArguments", - "MapAccStatsArguments", - "GaussToPostArguments", - "InitialFmllrArguments", - "ExtractIvectorsArguments", - "ExportTextGridArguments", - "CompileTrainGraphsArguments", - "CompileInformationArguments", - 
"CompileUtteranceTrainGraphsArguments", - "MonoAlignEqualArguments", - "TestUtterancesArguments", - "CarpaLmRescoreArguments", - "GeneratePronunciationsArguments", -] - - -class VadArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.features.compute_vad_func`""" - - log_path: str - dictionaries: List[str] - feats_scp_paths: Dict[str, str] - vad_scp_paths: Dict[str, str] - vad_options: MetaDict - - -class MfccArguments(NamedTuple): - """ - Arguments for :func:`~montreal_forced_aligner.multiprocessing.features.mfcc_func` - """ - - log_path: str - dictionaries: List[str] - feats_scp_paths: Dict[str, str] - lengths_paths: Dict[str, str] - segment_paths: Dict[str, str] - wav_paths: Dict[str, str] - mfcc_options: MetaDict - - -class CompileTrainGraphsArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.compile_train_graphs_func`""" - - log_path: str - dictionaries: List[str] - tree_path: str - model_path: str - text_int_paths: Dict[str, str] - disambig_paths: Dict[str, str] - lexicon_fst_paths: Dict[str, str] - fst_scp_paths: Dict[str, str] - - -class MonoAlignEqualArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.mono_align_equal_func`""" - - log_path: str - dictionaries: List[str] - feature_strings: Dict[str, str] - fst_scp_paths: Dict[str, str] - ali_ark_paths: Dict[str, str] - acc_paths: Dict[str, str] - model_path: str - - -class AccStatsArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.acc_stats_func`""" - - log_path: str - dictionaries: List[str] - feature_strings: Dict[str, str] - ali_paths: Dict[str, str] - acc_paths: Dict[str, str] - model_path: str - - -class AlignArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.align_func`""" - - log_path: str - dictionaries: List[str] - fst_scp_paths: Dict[str, str] - feature_strings: 
Dict[str, str] - model_path: str - ali_paths: Dict[str, str] - score_paths: Dict[str, str] - loglike_paths: Dict[str, str] - align_options: MetaDict - - -class CompileInformationArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.compile_information_func`""" - - align_log_paths: str - - -class AliToCtmArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.ali_to_ctm_func`""" - - log_path: str - dictionaries: List[str] - ali_paths: Dict[str, str] - text_int_paths: Dict[str, str] - word_boundary_int_paths: Dict[str, str] - frame_shift: float - model_path: str - ctm_paths: Dict[str, str] - word_mode: bool - - -class CleanupWordCtmArguments(NamedTuple): - """Arguments for :class:`~montreal_forced_aligner.multiprocessing.alignment.CleanupWordCtmProcessWorker`""" - - ctm_paths: Dict[str, str] - dictionaries: List[str] - utterances: Dict[str, Dict[str, Utterance]] - dictionary_data: Dict[str, DictionaryData] - - -class NoCleanupWordCtmArguments(NamedTuple): - """Arguments for :class:`~montreal_forced_aligner.multiprocessing.alignment.NoCleanupWordCtmProcessWorker`""" - - ctm_paths: Dict[str, str] - dictionaries: List[str] - utterances: Dict[str, Dict[str, Utterance]] - dictionary_data: Dict[str, DictionaryData] - - -class PhoneCtmArguments(NamedTuple): - """Arguments for :class:`~montreal_forced_aligner.multiprocessing.alignment.PhoneCtmProcessWorker`""" - - ctm_paths: Dict[str, str] - dictionaries: List[str] - utterances: Dict[str, Dict[str, Utterance]] - reversed_phone_mappings: Dict[str, ReversedMappingType] - positions: Dict[str, List[str]] - - -class CombineCtmArguments(NamedTuple): - """Arguments for :class:`~montreal_forced_aligner.multiprocessing.alignment.CombineProcessWorker`""" - - dictionaries: List[str] - files: Dict[str, File] - dictionary_data: Dict[str, DictionaryData] - cleanup_textgrids: bool - - -class ExportTextGridArguments(NamedTuple): - """Arguments for 
:class:`~montreal_forced_aligner.multiprocessing.alignment.ExportTextGridProcessWorker`""" - - files: Dict[str, File] - frame_shift: int - output_directory: str - backup_output_directory: str - - -class TreeStatsArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.tree_stats_func`""" - - log_path: str - dictionaries: List[str] - ci_phones: str - model_path: str - feature_strings: Dict[str, str] - ali_paths: Dict[str, str] - treeacc_paths: Dict[str, str] - - -class ConvertAlignmentsArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.convert_alignments_func`""" - - log_path: str - dictionaries: List[str] - model_path: str - tree_path: str - align_model_path: str - ali_paths: Dict[str, str] - new_ali_paths: Dict[str, str] - - -class AlignmentImprovementArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.compute_alignment_improvement_func`""" - - log_path: str - dictionaries: List[str] - model_path: str - text_int_paths: Dict[str, str] - word_boundary_paths: Dict[str, str] - ali_paths: Dict[str, str] - frame_shift: int - reversed_phone_mappings: Dict[str, Dict[int, str]] - positions: Dict[str, List[str]] - phone_ctm_paths: Dict[str, str] - - -class CalcFmllrArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.calc_fmllr_func`""" - - log_path: str - dictionaries: List[str] - feature_strings: Dict[str, str] - ali_paths: Dict[str, str] - ali_model_path: str - model_path: str - spk2utt_paths: Dict[str, str] - trans_paths: Dict[str, str] - fmllr_options: MetaDict - - -class AccStatsTwoFeatsArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.acc_stats_two_feats_func`""" - - log_path: str - dictionaries: List[str] - ali_paths: Dict[str, str] - acc_paths: Dict[str, str] - model_path: str - feature_strings: Dict[str, str] - si_feature_strings: 
Dict[str, str] - - -class LdaAccStatsArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.lda_acc_stats_func`""" - - log_path: str - dictionaries: List[str] - feature_strings: Dict[str, str] - ali_paths: Dict[str, str] - model_path: str - lda_options: MetaDict - acc_paths: Dict[str, str] - - -class CalcLdaMlltArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.calc_lda_mllt_func`""" - - log_path: str - dictionaries: List[str] - feature_strings: Dict[str, str] - ali_paths: Dict[str, str] - model_path: str - lda_options: MetaDict - macc_paths: Dict[str, str] - - -class MapAccStatsArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.map_acc_stats_func`""" - - log_path: str - dictionaries: List[str] - feature_strings: Dict[str, str] - model_path: str - ali_paths: Dict[str, str] - acc_paths: Dict[str, str] - - -class GmmGselectArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.ivector.gmm_gselect_func`""" - - log_path: str - dictionaries: List[str] - feature_strings: Dict[str, str] - ivector_options: MetaDict - dubm_model: str - gselect_paths: Dict[str, str] - - -class AccGlobalStatsArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.ivector.acc_global_stats_func`""" - - log_path: str - dictionaries: List[str] - feature_strings: Dict[str, str] - ivector_options: MetaDict - gselect_paths: Dict[str, str] - acc_paths: Dict[str, str] - dubm_path: str - - -class GaussToPostArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.ivector.gauss_to_post_func`""" - - log_path: str - dictionaries: List[str] - feature_strings: Dict[str, str] - ivector_options: MetaDict - post_paths: Dict[str, str] - dubm_path: str - - -class AccIvectorStatsArguments(NamedTuple): - """Arguments for 
:func:`~montreal_forced_aligner.multiprocessing.ivector.acc_ivector_stats_func`""" - - log_path: str - dictionaries: List[str] - feature_strings: Dict[str, str] - ivector_options: MetaDict - ie_path: str - post_paths: Dict[str, str] - acc_init_paths: Dict[str, str] - - -class ExtractIvectorsArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.ivector.extract_ivectors_func`""" - - log_path: str - dictionaries: List[str] - feature_strings: Dict[str, str] - ivector_options: MetaDict - ali_paths: Dict[str, str] - ie_path: str - ivector_paths: Dict[str, str] - weight_paths: Dict[str, str] - model_path: str - dubm_path: str - - -class CompileUtteranceTrainGraphsArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.compile_utterance_train_graphs_func`""" - - log_path: str - dictionaries: List[str] - disambig_int_paths: Dict[str, str] - disambig_L_fst_paths: Dict[str, str] - fst_paths: Dict[str, str] - graphs_paths: Dict[str, str] - model_path: str - tree_path: str - - -class TestUtterancesArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.test_utterances_func`""" - - log_path: str - dictionaries: List[str] - feature_strings: Dict[str, str] - words_paths: Dict[str, str] - graphs_paths: Dict[str, str] - text_int_paths: Dict[str, str] - edits_paths: Dict[str, str] - out_int_paths: Dict[str, str] - model_path: str - - -class SegmentVadArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.ivector.segment_vad_func`""" - - dictionaries: List[str] - vad_paths: Dict[str, str] - segmentation_options: MetaDict - - -class GeneratePronunciationsArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.pronunciations.generate_pronunciations_func`""" - - log_path: str - dictionaries: List[str] - text_int_paths: Dict[str, str] - word_boundary_paths: Dict[str, str] - ali_paths: Dict[str, str] 
- model_path: str - pron_paths: Dict[str, str] - - -class CreateHclgArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.transcription.create_hclg_func`""" - - log_path: str - working_directory: str - path_template: str - words_path: str - carpa_path: str - small_arpa_path: str - medium_arpa_path: str - big_arpa_path: str - model_path: str - disambig_L_path: str - disambig_int_path: str - hclg_options: MetaDict - words_mapping: MappingType - - @property - def hclg_path(self) -> str: - return self.path_template.format(file_name="HCLG") - - -class DecodeArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.transcription.decode_func`""" - - log_path: str - dictionaries: List[str] - feature_strings: Dict[str, str] - decode_options: MetaDict - model_path: str - lat_paths: Dict[str, str] - words_paths: Dict[str, str] - hclg_paths: Dict[str, str] - - -class ScoreArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.transcription.score_func`""" - - log_path: str - dictionaries: List[str] - score_options: MetaDict - lat_paths: Dict[str, str] - rescored_lat_paths: Dict[str, str] - carpa_rescored_lat_paths: Dict[str, str] - words_paths: Dict[str, str] - tra_paths: Dict[str, str] - - -class LmRescoreArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.transcription.lm_rescore_func`""" - - log_path: str - dictionaries: List[str] - lm_rescore_options: MetaDict - lat_paths: Dict[str, str] - rescored_lat_paths: Dict[str, str] - old_g_paths: Dict[str, str] - new_g_paths: Dict[str, str] - - -class CarpaLmRescoreArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.transcription.carpa_lm_rescore_func`""" - - log_path: str - dictionaries: List[str] - lat_paths: Dict[str, str] - rescored_lat_paths: Dict[str, str] - old_g_paths: Dict[str, str] - new_g_paths: Dict[str, str] - - -class 
InitialFmllrArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.transcription.initial_fmllr_func`""" - - log_path: str - dictionaries: List[str] - feature_strings: Dict[str, str] - model_path: str - fmllr_options: MetaDict - pre_trans_paths: Dict[str, str] - lat_paths: Dict[str, str] - spk2utt_paths: Dict[str, str] - - -class LatGenFmllrArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.transcription.lat_gen_fmllr_func`""" - - log_path: str - dictionaries: List[str] - feature_strings: Dict[str, str] - model_path: str - decode_options: MetaDict - words_paths: Dict[str, str] - hclg_paths: Dict[str, str] - tmp_lat_paths: Dict[str, str] - - -class FinalFmllrArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.transcription.final_fmllr_est_func`""" - - log_path: str - dictionaries: List[str] - feature_strings: Dict[str, str] - model_path: str - fmllr_options: MetaDict - trans_paths: Dict[str, str] - spk2utt_paths: Dict[str, str] - tmp_lat_paths: Dict[str, str] - - -class FmllrRescoreArguments(NamedTuple): - """Arguments for :func:`~montreal_forced_aligner.multiprocessing.transcription.fmllr_rescore_func`""" - - log_path: str - dictionaries: List[str] - feature_strings: Dict[str, str] - model_path: str - fmllr_options: MetaDict - tmp_lat_paths: Dict[str, str] - final_lat_paths: Dict[str, str] - - -class Job: - """ - Class representing information about corpus jobs that will be run in parallel. - Jobs have a set of speakers that they will process, along with all files and utterances associated with that speaker. - As such, Jobs also have a set of dictionaries that the speakers use, and argument outputs are largely dependent on - the pronunciation dictionaries in use. 
- - Parameters - ---------- - name: int - Job number is the job's identifier - - Attributes - ---------- - speakers: List[:class:`~montreal_forced_aligner.corpus.Speaker`] - List of speakers associated with this job - dictionaries: Set[:class:`~montreal_forced_aligner.dictionary.PronunciationDictionary`] - Set of dictionaries that the job's speakers use - subset_utts: Set[:class:`~montreal_forced_aligner.corpus.Utterance`] - When trainers are just using a subset of the corpus, the subset of utterances on each job will be set and used to - filter the job's utterances - subset_speakers: Set[:class:`~montreal_forced_aligner.corpus.Speaker`] - When subset_utts is set, this property will be calculated as the subset of speakers that the utterances correspond to - subset_dictionaries: Set[:class:`~montreal_forced_aligner.dictionary.PronunciationDictionary`] - Subset of dictionaries that the subset of speakers use - - """ - - def __init__(self, name: int): - self.name = name - self.speakers: List[Speaker] = [] - self.dictionaries = set() - - self.subset_utts = set() - self.subset_speakers = set() - self.subset_dictionaries = set() - - def add_speaker(self, speaker: Speaker) -> None: - """ - Add a speaker to a job - - Parameters - ---------- - speaker: :class:`~montreal_forced_aligner.corpus.Speaker` - Speaker to add - """ - self.speakers.append(speaker) - self.dictionaries.add(speaker.dictionary) - - def set_subset(self, subset_utts: Optional[Collection[Utterance]]) -> None: - """ - Set the current subset for the trainer - - Parameters - ---------- - subset_utts: Collection[:class:`~montreal_forced_aligner.corpus.Utterance`], optional - Subset of utterances for this job to use - """ - if subset_utts is None: - self.subset_utts = set() - self.subset_speakers = set() - self.subset_dictionaries = set() - else: - self.subset_utts = set(subset_utts) - self.subset_speakers = {u.speaker for u in subset_utts if u.speaker in self.speakers} - self.subset_dictionaries = {s.dictionary 
for s in self.subset_speakers} - - def text_scp_data(self) -> Dict[str, Dict[str, List[str]]]: - """ - Generate the job's data for Kaldi's text scp files - - Returns - ------- - Dict[str, Dict[str, List[str]]] - Text for each utterance, per dictionary name - """ - data = {} - for s in self.speakers: - if s.dictionary is None: - key = None - else: - key = s.dictionary.name - if key not in data: - data[key] = {} - if self.subset_speakers and s not in self.subset_speakers: - continue - for u in s.utterances.values(): - if u.ignored: - continue - if self.subset_utts and u not in self.subset_utts: - continue - if not u.text: - continue - data[key][u.name] = u.text_for_scp() - return data - - def text_int_scp_data(self) -> Dict[str, Dict[str, str]]: - """ - Generate the job's data for Kaldi's text int scp files - - Returns - ------- - Dict[str, Dict[str, str]] - Text converted to integer IDs for each utterance, per dictionary name - """ - data = {} - for s in self.speakers: - if s.dictionary is None: - continue - else: - key = s.dictionary.name - if key not in data: - data[key] = {} - if self.subset_speakers and s not in self.subset_speakers: - continue - for u in s.utterances.values(): - if self.subset_utts and u not in self.subset_utts: - continue - if u.ignored: - continue - if not u.text: - continue - data[key][u.name] = " ".join(map(str, u.text_int_for_scp())) - return data - - def wav_scp_data(self) -> Dict[str, Dict[str, str]]: - """ - Generate the job's data for Kaldi's wav scp files - - Returns - ------- - Dict[str, Dict[str, str]] - Wav scp strings for each file, per dictionary name - """ - data = {} - done = {} - for s in self.speakers: - if s.dictionary is None: - key = None - else: - key = s.dictionary.name - if key not in data: - data[key] = {} - done[key] = set() - if self.subset_speakers and s not in self.subset_speakers: - continue - for u in s.utterances.values(): - if u.ignored: - continue - if self.subset_utts and u not in self.subset_utts: - continue 
- if not u.is_segment: - data[key][u.name] = u.file.for_wav_scp() - elif u.file.name not in done: - data[key][u.file.name] = u.file.for_wav_scp() - done[key].add(u.file.name) - return data - - def utt2spk_scp_data(self) -> Dict[str, Dict[str, str]]: - """ - Generate the job's data for Kaldi's utt2spk scp files - - Returns - ------- - Dict[str, Dict[str, str]] - Utterance to speaker mapping, per dictionary name - """ - data = {} - for s in self.speakers: - if s.dictionary is None: - key = None - else: - key = s.dictionary.name - if key not in data: - data[key] = {} - if self.subset_speakers and s not in self.subset_speakers: - continue - for u in s.utterances.values(): - if u.ignored: - continue - if self.subset_utts and u not in self.subset_utts: - continue - data[key][u.name] = s.name - return data - - def feat_scp_data(self) -> Dict[str, Dict[str, str]]: - """ - Generate the job's data for Kaldi's feature scp files - - Returns - ------- - Dict[str, Dict[str, str]] - Utterance to feature archive ID mapping, per dictionary name - """ - data = {} - for s in self.speakers: - if s.dictionary is None: - key = None - else: - key = s.dictionary.name - if key not in data: - data[key] = {} - if self.subset_speakers and s not in self.subset_speakers: - continue - for u in s.utterances.values(): - if u.ignored: - continue - if self.subset_utts and u not in self.subset_utts: - continue - if u.features: - data[key][u.name] = u.features - return data - - def spk2utt_scp_data(self) -> Dict[str, Dict[str, List[str]]]: - """ - Generate the job's data for Kaldi's spk2utt scp files - - Returns - ------- - Dict[str, Dict[str, List[str]]] - Speaker to utterance mapping, per dictionary name - """ - data = {} - for s in self.speakers: - if s.dictionary is None: - key = None - else: - key = s.dictionary.name - if key not in data: - data[key] = {} - if self.subset_speakers and s not in self.subset_speakers: - continue - data[key][s.name] = sorted( - [ - u.name - for u in 
s.utterances.values() - if not u.ignored and not (self.subset_utts and u not in self.subset_utts) - ] - ) - return data - - def cmvn_scp_data(self) -> Dict[str, Dict[str, str]]: - """ - Generate the job's data for Kaldi's CMVN scp files - - Returns - ------- - Dict[str, Dict[str, str]] - Speaker to CMVN mapping, per dictionary name - """ - data = {} - for s in self.speakers: - if s.dictionary is None: - key = None - else: - key = s.dictionary.name - if key not in data: - data[key] = {} - if self.subset_speakers and s not in self.subset_speakers: - continue - if s.cmvn: - data[key][s.name] = s.cmvn - return data - - def segments_scp_data(self) -> Dict[str, Dict[str, str]]: - """ - Generate the job's data for Kaldi's segments scp files - - Returns - ------- - Dict[str, Dict[str, str]] - Utterance to segment mapping, per dictionary name - """ - data = {} - for s in self.speakers: - if s.dictionary is None: - key = None - else: - key = s.dictionary.name - if key not in data: - data[key] = {} - if self.subset_speakers and s not in self.subset_speakers: - continue - for u in s.utterances.values(): - if u.ignored: - continue - if self.subset_utts and u not in self.subset_utts: - continue - if not u.is_segment: - continue - data[key][u.name] = u.segment_for_scp() - return data - - def construct_path_dictionary( - self, directory: str, identifier: str, extension: str - ) -> Dict[str, str]: - """ - Helper function for constructing dictionary-dependent paths for the Job - - Parameters - ---------- - directory: str - Directory to use as the root - identifier: str - Identifier for the path name, like ali or acc - extension: str - Extension of the path, like .scp or .ark - - Returns - ------- - Dict[str, str] - Path for each dictionary - """ - output = {} - for dict_name in self.current_dictionary_names: - output[dict_name] = os.path.join( - directory, f"{identifier}.{dict_name}.{self.name}.{extension}" - ) - return output - - def construct_dictionary_dependent_paths( - self, 
directory: str, identifier: str, extension: str - ) -> Dict[str, str]: - """ - Helper function for constructing paths that depend only on the dictionaries of the job, and not the job name itself. - These paths should be merged with all other jobs to get a full set of dictionary paths. - - Parameters - ---------- - directory: str - Directory to use as the root - identifier: str - Identifier for the path name, like ali or acc - extension: str - Extension of the path, like .scp or .ark - - Returns - ------- - Dict[str, str] - Path for each dictionary - """ - output = {} - for dict_name in self.current_dictionary_names: - output[dict_name] = os.path.join(directory, f"{identifier}.{dict_name}.{extension}") - return output - - @property - def dictionary_count(self): - """Number of dictionaries currently used""" - if self.subset_dictionaries: - return len(self.subset_dictionaries) - return len(self.dictionaries) - - @property - def current_dictionaries(self): - """Current dictionaries depending on whether a subset is being used""" - if self.subset_dictionaries: - return self.subset_dictionaries - return self.dictionaries - - @property - def current_dictionary_names(self): - """Current dictionary names depending on whether a subset is being used""" - if self.subset_dictionaries: - return sorted(x.name for x in self.subset_dictionaries) - if self.dictionaries == {None}: - return [None] - return sorted(x.name for x in self.dictionaries) - - def set_feature_config(self, feature_config: FeatureConfig) -> None: - """ - Set the feature configuration to use for the Job - - Parameters - ---------- - feature_config: :class:`~montreal_forced_aligner.config.FeatureConfig` - Feature configuration - """ - self.feature_config = feature_config - - def construct_base_feature_string(self, corpus: Corpus, all_feats: bool = False) -> str: - """ - Construct the base feature string independent of job name - - Used in initialization of MonophoneTrainer (to get dimension size) and IvectorTrainer 
(uses all feats) - - Parameters - ---------- - corpus: :class:`~montreal_forced_aligner.corpus.Corpus` - Corpus to use as the source - all_feats: bool - Flag for whether all features across all jobs should be taken into account - - Returns - ------- - str - Feature string - """ - if all_feats: - feat_path = os.path.join(corpus.output_directory, "feats.scp") - utt2spk_path = os.path.join(corpus.output_directory, "utt2spk.scp") - cmvn_path = os.path.join(corpus.output_directory, "cmvn.scp") - feats = f"ark,s,cs:apply-cmvn --utt2spk=ark:{utt2spk_path} scp:{cmvn_path} scp:{feat_path} ark:- |" - feats += " add-deltas ark:- ark:- |" - return feats - utt2spks = self.construct_path_dictionary(corpus.split_directory, "utt2spk", "scp") - cmvns = self.construct_path_dictionary(corpus.split_directory, "cmvn", "scp") - features = self.construct_path_dictionary(corpus.split_directory, "feats", "scp") - for dict_name in self.current_dictionary_names: - feat_path = features[dict_name] - cmvn_path = cmvns[dict_name] - utt2spk_path = utt2spks[dict_name] - feats = f"ark,s,cs:apply-cmvn --utt2spk=ark:{utt2spk_path} scp:{cmvn_path} scp:{feat_path} ark:- |" - if self.feature_config.deltas: - feats += " add-deltas ark:- ark:- |" - - return feats - - def construct_feature_proc_strings( - self, - aligner: MfaWorker, - speaker_independent: bool = False, - ) -> Dict[str, str]: - """ - Constructs a feature processing string to supply to Kaldi binaries, taking into account corpus features and the - current working directory of the aligner (whether fMLLR or LDA transforms should be used, etc). 
- - Parameters - ---------- - aligner: :class:`~montreal_forced_aligner.abc.MfaWorker` - Aligner, Transcriber or other main utility class that uses the features - speaker_independent: bool - Flag for whether features should be speaker-independent regardless of the presence of fMLLR transforms - - Returns - ------- - Dict[str, str] - Feature strings per dictionary name - """ - lda_mat_path = None - fmllrs = {} - if aligner.working_directory is not None: - lda_mat_path = os.path.join(aligner.working_directory, "lda.mat") - if not os.path.exists(lda_mat_path): - lda_mat_path = None - - fmllrs = self.construct_path_dictionary(aligner.working_directory, "trans", "ark") - utt2spks = self.construct_path_dictionary(aligner.data_directory, "utt2spk", "scp") - cmvns = self.construct_path_dictionary(aligner.data_directory, "cmvn", "scp") - features = self.construct_path_dictionary(aligner.data_directory, "feats", "scp") - vads = self.construct_path_dictionary(aligner.data_directory, "vad", "scp") - feat_strings = {} - for dict_name in self.current_dictionary_names: - feat_path = features[dict_name] - cmvn_path = cmvns[dict_name] - utt2spk_path = utt2spks[dict_name] - fmllr_trans_path = None - try: - fmllr_trans_path = fmllrs[dict_name] - if not os.path.exists(fmllr_trans_path): - fmllr_trans_path = None - except KeyError: - pass - vad_path = vads[dict_name] - if aligner.uses_voiced: - feats = f"ark,s,cs:add-deltas scp:{feat_path} ark:- |" - if aligner.uses_cmvn: - feats += " apply-cmvn-sliding --norm-vars=false --center=true --cmn-window=300 ark:- ark:- |" - feats += f" select-voiced-frames ark:- scp,s,cs:{vad_path} ark:- |" - elif not os.path.exists(cmvn_path) and aligner.uses_cmvn: - feats = f"ark,s,cs:add-deltas scp:{feat_path} ark:- |" - if aligner.uses_cmvn: - feats += " apply-cmvn-sliding --norm-vars=false --center=true --cmn-window=300 ark:- ark:- |" - else: - feats = f"ark,s,cs:apply-cmvn --utt2spk=ark:{utt2spk_path} scp:{cmvn_path} scp:{feat_path} ark:- |" - if 
lda_mat_path is not None: - if not os.path.exists(lda_mat_path): - raise Exception(f"Could not find {lda_mat_path}") - feats += f" splice-feats --left-context={self.feature_config.splice_left_context} --right-context={self.feature_config.splice_right_context} ark:- ark:- |" - feats += f" transform-feats {lda_mat_path} ark:- ark:- |" - elif aligner.uses_splices: - feats += f" splice-feats --left-context={self.feature_config.splice_left_context} --right-context={self.feature_config.splice_right_context} ark:- ark:- |" - elif self.feature_config.deltas: - feats += " add-deltas ark:- ark:- |" - - if fmllr_trans_path is not None and not ( - aligner.speaker_independent or speaker_independent - ): - if not os.path.exists(fmllr_trans_path): - raise Exception(f"Could not find {fmllr_trans_path}") - feats += f" transform-feats --utt2spk=ark:{utt2spk_path} ark:{fmllr_trans_path} ark:- ark:- |" - feat_strings[dict_name] = feats - return feat_strings - - def compile_utterance_train_graphs_arguments( - self, validator: CorpusValidator - ) -> CompileUtteranceTrainGraphsArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.compile_utterance_train_graphs_func` - - Parameters - ---------- - validator: :class:`~montreal_forced_aligner.validator.CorpusValidator` - Validator - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.CompileUtteranceTrainGraphsArguments` - Arguments for processing - """ - dictionary_paths = validator.dictionary.output_paths - disambig_paths = { - k: os.path.join(v, "phones", "disambiguation_symbols.int") - for k, v in dictionary_paths.items() - } - lexicon_fst_paths = { - k: os.path.join(v, "L_disambig.fst") for k, v in dictionary_paths.items() - } - return CompileUtteranceTrainGraphsArguments( - os.path.join( - validator.trainer.working_log_directory, f"utterance_fst.{self.name}.log" - ), - self.current_dictionary_names, - disambig_paths, - lexicon_fst_paths, - 
self.construct_path_dictionary(validator.trainer.data_directory, "utt2fst", "scp"), - self.construct_path_dictionary( - validator.trainer.working_directory, "utterance_graphs", "fst" - ), - validator.trainer.current_model_path, - validator.trainer.tree_path, - ) - - def test_utterances_arguments(self, validator: CorpusValidator) -> TestUtterancesArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.test_utterances_func` - - Parameters - ---------- - validator: :class:`~montreal_forced_aligner.validator.CorpusValidator` - Validator - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.TestUtterancesArguments` - Arguments for processing - """ - dictionary_paths = validator.dictionary.output_paths - words_paths = {k: os.path.join(v, "words.txt") for k, v in dictionary_paths.items()} - return TestUtterancesArguments( - os.path.join(validator.trainer.working_directory, f"utterance_fst.{self.name}.log"), - self.current_dictionary_names, - self.construct_feature_proc_strings(validator.trainer), - words_paths, - self.construct_path_dictionary( - validator.trainer.working_directory, "utterance_graphs", "fst" - ), - self.construct_path_dictionary(validator.trainer.data_directory, "text", "int.scp"), - self.construct_path_dictionary(validator.trainer.working_directory, "edits", "scp"), - self.construct_path_dictionary(validator.trainer.working_directory, "aligned", "int"), - validator.trainer.current_model_path, - ) - - def extract_ivector_arguments( - self, ivector_extractor: IvectorExtractor - ) -> ExtractIvectorsArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.ivector.extract_ivectors_func` - - Parameters - ---------- - ivector_extractor: :class:`~montreal_forced_aligner.abc.IvectorExtractor` - Ivector extractor - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.ExtractIvectorsArguments` - Arguments for processing - """ - 
return ExtractIvectorsArguments( - os.path.join( - ivector_extractor.working_log_directory, f"extract_ivectors.{self.name}.log" - ), - self.current_dictionary_names, - self.construct_feature_proc_strings(ivector_extractor), - ivector_extractor.ivector_options, - self.construct_path_dictionary(ivector_extractor.working_directory, "ali", "ark"), - ivector_extractor.ie_path, - self.construct_path_dictionary(ivector_extractor.working_directory, "ivectors", "scp"), - self.construct_path_dictionary(ivector_extractor.working_directory, "weights", "ark"), - ivector_extractor.model_path, - ivector_extractor.dubm_path, - ) - - def create_hclgs_arguments(self, transcriber: Transcriber) -> Dict[str, CreateHclgArguments]: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.transcription.create_hclg_func` - - Parameters - ---------- - transcriber: :class:`~montreal_forced_aligner.transcriber.Transcriber` - Transcriber - - Returns - ------- - Dict[str, :class:`~montreal_forced_aligner.multiprocessing.classes.CreateHclgArguments`] - Per dictionary arguments for HCLG - """ - args = {} - - for dictionary in self.current_dictionaries: - dict_name = dictionary.name - args[dict_name] = CreateHclgArguments( - os.path.join(transcriber.model_directory, "log", f"hclg.{dict_name}.log"), - transcriber.model_directory, - os.path.join(transcriber.model_directory, "{file_name}" + f".{dict_name}.fst"), - os.path.join(transcriber.model_directory, f"words.{dict_name}.txt"), - os.path.join(transcriber.model_directory, f"G.{dict_name}.carpa"), - transcriber.language_model.small_arpa_path, - transcriber.language_model.medium_arpa_path, - transcriber.language_model.carpa_path, - transcriber.model_path, - dictionary.disambig_path, - os.path.join(dictionary.phones_dir, "disambiguation_symbols.int"), - transcriber.hclg_options, - dictionary.words_mapping, - ) - return args - - def decode_arguments(self, transcriber: Transcriber) -> DecodeArguments: - """ - Generate Job 
arguments for :func:`~montreal_forced_aligner.multiprocessing.transcription.decode_func` - - Parameters - ---------- - transcriber: :class:`~montreal_forced_aligner.transcriber.Transcriber` - Transcriber - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.DecodeArguments` - Arguments for processing - """ - return DecodeArguments( - os.path.join(transcriber.working_log_directory, f"decode.{self.name}.log"), - self.current_dictionary_names, - self.construct_feature_proc_strings(transcriber), - transcriber.transcribe_config.decode_options, - transcriber.alignment_model_path, - self.construct_path_dictionary(transcriber.working_directory, "lat", "ark"), - self.construct_dictionary_dependent_paths(transcriber.model_directory, "words", "txt"), - self.construct_dictionary_dependent_paths(transcriber.model_directory, "HCLG", "fst"), - ) - - def score_arguments(self, transcriber: Transcriber) -> ScoreArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.transcription.score_func` - - Parameters - ---------- - transcriber: :class:`~montreal_forced_aligner.transcriber.Transcriber` - Transcriber - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.ScoreArguments` - Arguments for processing - """ - return ScoreArguments( - os.path.join(transcriber.working_log_directory, f"score.{self.name}.log"), - self.current_dictionary_names, - transcriber.transcribe_config.score_options, - self.construct_path_dictionary(transcriber.working_directory, "lat", "ark"), - self.construct_path_dictionary(transcriber.working_directory, "lat.rescored", "ark"), - self.construct_path_dictionary( - transcriber.working_directory, "lat.carpa.rescored", "ark" - ), - self.construct_dictionary_dependent_paths(transcriber.model_directory, "words", "txt"), - self.construct_path_dictionary(transcriber.evaluation_directory, "tra", "scp"), - ) - - def lm_rescore_arguments(self, transcriber: Transcriber) -> 
LmRescoreArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.transcription.lm_rescore_func` - - Parameters - ---------- - transcriber: :class:`~montreal_forced_aligner.transcriber.Transcriber` - Transcriber - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.LmRescoreArguments` - Arguments for processing - """ - return LmRescoreArguments( - os.path.join(transcriber.working_log_directory, f"lm_rescore.{self.name}.log"), - self.current_dictionary_names, - transcriber.transcribe_config.lm_rescore_options, - self.construct_path_dictionary(transcriber.working_directory, "lat", "ark"), - self.construct_path_dictionary(transcriber.working_directory, "lat.rescored", "ark"), - self.construct_dictionary_dependent_paths( - transcriber.model_directory, "G.small", "fst" - ), - self.construct_dictionary_dependent_paths(transcriber.model_directory, "G.med", "fst"), - ) - - def carpa_lm_rescore_arguments(self, transcriber: Transcriber) -> CarpaLmRescoreArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.transcription.carpa_lm_rescore_func` - - Parameters - ---------- - transcriber: :class:`~montreal_forced_aligner.transcriber.Transcriber` - Transcriber - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.CarpaLmRescoreArguments` - Arguments for processing - """ - return CarpaLmRescoreArguments( - os.path.join(transcriber.working_log_directory, f"carpa_lm_rescore.{self.name}.log"), - self.current_dictionary_names, - self.construct_path_dictionary(transcriber.working_directory, "lat.rescored", "ark"), - self.construct_path_dictionary( - transcriber.working_directory, "lat.carpa.rescored", "ark" - ), - self.construct_dictionary_dependent_paths(transcriber.model_directory, "G.med", "fst"), - self.construct_dictionary_dependent_paths(transcriber.model_directory, "G", "carpa"), - ) - - def initial_fmllr_arguments(self, transcriber: Transcriber) -> 
InitialFmllrArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.transcription.initial_fmllr_func` - - Parameters - ---------- - transcriber: :class:`~montreal_forced_aligner.transcriber.Transcriber` - Transcriber - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.InitialFmllrArguments` - Arguments for processing - """ - return InitialFmllrArguments( - os.path.join(transcriber.working_log_directory, f"initial_fmllr.{self.name}.log"), - self.current_dictionary_names, - self.construct_feature_proc_strings(transcriber), - transcriber.model_path, - transcriber.fmllr_options, - self.construct_path_dictionary(transcriber.working_directory, "trans", "ark"), - self.construct_path_dictionary(transcriber.working_directory, "lat", "ark"), - self.construct_path_dictionary(transcriber.data_directory, "spk2utt", "scp"), - ) - - def lat_gen_fmllr_arguments(self, transcriber: Transcriber) -> LatGenFmllrArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.transcription.lat_gen_fmllr_func` - - Parameters - ---------- - transcriber: :class:`~montreal_forced_aligner.transcriber.Transcriber` - Transcriber - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.LatGenFmllrArguments` - Arguments for processing - """ - return LatGenFmllrArguments( - os.path.join(transcriber.working_log_directory, f"lat_gen_fmllr.{self.name}.log"), - self.current_dictionary_names, - self.construct_feature_proc_strings(transcriber), - transcriber.model_path, - transcriber.transcribe_config.decode_options, - self.construct_dictionary_dependent_paths(transcriber.model_directory, "words", "txt"), - self.construct_dictionary_dependent_paths(transcriber.model_directory, "HCLG", "fst"), - self.construct_path_dictionary(transcriber.working_directory, "lat.tmp", "ark"), - ) - - def final_fmllr_arguments(self, transcriber: Transcriber) -> FinalFmllrArguments: - """ - Generate Job 
arguments for :func:`~montreal_forced_aligner.multiprocessing.transcription.final_fmllr_est_func` - - Parameters - ---------- - transcriber: :class:`~montreal_forced_aligner.transcriber.Transcriber` - Transcriber - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.FinalFmllrArguments` - Arguments for processing - """ - return FinalFmllrArguments( - os.path.join(transcriber.working_log_directory, f"final_fmllr.{self.name}.log"), - self.current_dictionary_names, - self.construct_feature_proc_strings(transcriber), - transcriber.model_path, - transcriber.fmllr_options, - self.construct_path_dictionary(transcriber.working_directory, "trans", "ark"), - self.construct_path_dictionary(transcriber.data_directory, "spk2utt", "scp"), - self.construct_path_dictionary(transcriber.working_directory, "lat.tmp", "ark"), - ) - - def fmllr_rescore_arguments(self, transcriber: Transcriber) -> FmllrRescoreArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.transcription.fmllr_rescore_func` - - Parameters - ---------- - transcriber: :class:`~montreal_forced_aligner.transcriber.Transcriber` - Transcriber - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.FmllrRescoreArguments` - Arguments for processing - """ - return FmllrRescoreArguments( - os.path.join(transcriber.working_log_directory, f"fmllr_rescore.{self.name}.log"), - self.current_dictionary_names, - self.construct_feature_proc_strings(transcriber), - transcriber.model_path, - transcriber.fmllr_options, - self.construct_path_dictionary(transcriber.working_directory, "lat.tmp", "ark"), - self.construct_path_dictionary(transcriber.working_directory, "lat", "ark"), - ) - - def vad_arguments(self, corpus: Corpus) -> VadArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.features.compute_vad_func` - - Parameters - ---------- - corpus: :class:`~montreal_forced_aligner.corpus.Corpus` - Corpus - 
- Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.VadArguments` - Arguments for processing - """ - return VadArguments( - os.path.join(corpus.split_directory, "log", f"compute_vad.{self.name}.log"), - self.current_dictionary_names, - self.construct_path_dictionary(corpus.split_directory, "feats", "scp"), - self.construct_path_dictionary(corpus.split_directory, "vad", "scp"), - corpus.vad_config, - ) - - def segments_vad_arguments(self, segmenter: Segmenter) -> SegmentVadArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.ivector.segment_vad_func` - - Parameters - ---------- - segmenter: :class:`~montreal_forced_aligner.segmenter.Segmenter` - Segmenter - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.SegmentVadArguments` - Arguments for processing - """ - return SegmentVadArguments( - self.current_dictionary_names, - self.construct_path_dictionary(segmenter.corpus.split_directory, "vad", "scp"), - segmenter.segmentation_config.segmentation_options, - ) - - def mfcc_arguments(self, corpus: Corpus) -> MfccArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.features.mfcc_func` - - Parameters - ---------- - corpus: :class:`~montreal_forced_aligner.corpus.Corpus` - Corpus - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.MfccArguments` - Arguments for processing - """ - return MfccArguments( - os.path.join(corpus.split_directory, "log", f"make_mfcc.{self.name}.log"), - self.current_dictionary_names, - self.construct_path_dictionary(corpus.split_directory, "feats", "scp"), - self.construct_path_dictionary(corpus.split_directory, "utterance_lengths", "scp"), - self.construct_path_dictionary(corpus.split_directory, "segments", "scp"), - self.construct_path_dictionary(corpus.split_directory, "wav", "scp"), - self.feature_config.mfcc_options, - ) - - def acc_stats_arguments(self, aligner: BaseTrainer) 
-> AccStatsArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.acc_stats_func` - - Parameters - ---------- - aligner: :class:`~montreal_forced_aligner.trainers.BaseTrainer` - Aligner - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.AccStatsArguments` - Arguments for processing - """ - return AccStatsArguments( - os.path.join( - aligner.working_directory, "log", f"acc.{aligner.iteration}.{self.name}.log" - ), - self.current_dictionary_names, - self.construct_feature_proc_strings(aligner), - self.construct_path_dictionary(aligner.working_directory, "ali", "ark"), - self.construct_path_dictionary( - aligner.working_directory, str(aligner.iteration), "acc" - ), - aligner.current_model_path, - ) - - def mono_align_equal_arguments(self, aligner: MonophoneTrainer) -> MonoAlignEqualArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.mono_align_equal_func` - - Parameters - ---------- - aligner: :class:`~montreal_forced_aligner.trainers.MonophoneTrainer` - Aligner - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.MonoAlignEqualArguments` - Arguments for processing - """ - return MonoAlignEqualArguments( - os.path.join(aligner.working_log_directory, f"mono_align_equal.{self.name}.log"), - self.current_dictionary_names, - self.construct_feature_proc_strings(aligner), - self.construct_path_dictionary(aligner.working_directory, "fsts", "scp"), - self.construct_path_dictionary(aligner.working_directory, "ali", "ark"), - self.construct_path_dictionary(aligner.working_directory, "0", "acc"), - aligner.current_model_path, - ) - - def align_arguments(self, aligner: Aligner) -> AlignArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.align_func` - - Parameters - ---------- - aligner: Aligner - Aligner - - Returns - ------- - 
:class:`~montreal_forced_aligner.multiprocessing.classes.AlignArguments` - Arguments for processing - """ - if aligner.iteration is not None: - log_path = os.path.join( - aligner.working_log_directory, f"align.{aligner.iteration}.{self.name}.log" - ) - else: - log_path = os.path.join(aligner.working_log_directory, f"align.{self.name}.log") - return AlignArguments( - log_path, - self.current_dictionary_names, - self.construct_path_dictionary(aligner.working_directory, "fsts", "scp"), - self.construct_feature_proc_strings(aligner), - aligner.alignment_model_path, - self.construct_path_dictionary(aligner.working_directory, "ali", "ark"), - self.construct_path_dictionary(aligner.working_directory, "ali", "scores"), - self.construct_path_dictionary(aligner.working_directory, "ali", "loglikes"), - aligner.align_options, - ) - - def compile_information_arguments(self, aligner: BaseTrainer) -> CompileInformationArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.compile_information_func` - - Parameters - ---------- - aligner: :class:`~montreal_forced_aligner.trainers.BaseTrainer` - Aligner - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.CompileInformationArguments` - Arguments for processing - """ - if aligner.iteration is not None: - log_path = os.path.join( - aligner.working_log_directory, f"align.{aligner.iteration}.{self.name}.log" - ) - else: - log_path = os.path.join(aligner.working_log_directory, f"align.{self.name}.log") - return CompileInformationArguments(log_path) - - def word_boundary_int_files(self) -> Dict[str, str]: - """ - Generate mapping for dictionaries to word boundary int files - - Returns - ------- - Dict[str, ReversedMappingType] - Per dictionary word boundary int files - """ - data = {} - for dictionary in self.current_dictionaries: - data[dictionary.name] = os.path.join(dictionary.phones_dir, "word_boundary.int") - return data - - def reversed_phone_mappings(self) 
-> Dict[str, ReversedMappingType]: - """ - Generate mapping for dictionaries to reversed phone mapping - - Returns - ------- - Dict[str, ReversedMappingType] - Per dictionary reversed phone mapping - """ - data = {} - for dictionary in self.current_dictionaries: - data[dictionary.name] = dictionary.reversed_phone_mapping - return data - - def reversed_word_mappings(self) -> Dict[str, ReversedMappingType]: - """ - Generate mapping for dictionaries to reversed word mapping - - Returns - ------- - Dict[str, ReversedMappingType] - Per dictionary reversed word mapping - """ - data = {} - for dictionary in self.current_dictionaries: - data[dictionary.name] = dictionary.reversed_word_mapping - return data - - def words_mappings(self) -> Dict[str, MappingType]: - """ - Generate mapping for dictionaries to word mapping - - Returns - ------- - Dict[str, MappingType] - Per dictionary word mapping - """ - data = {} - for dictionary in self.current_dictionaries: - data[dictionary.name] = dictionary.words_mapping - return data - - def words(self) -> Dict[str, WordsType]: - """ - Generate mapping for dictionaries to words - - Returns - ------- - Dict[str, WordsType] - Per dictionary words - """ - data = {} - for dictionary in self.current_dictionaries: - data[dictionary.name] = dictionary.words - return data - - def punctuation(self): - """ - Generate mapping for dictionaries to punctuation - - Returns - ------- - Dict[str, str] - Per dictionary punctuation - """ - data = {} - for dictionary in self.current_dictionaries: - data[dictionary.name] = dictionary.punctuation - return data - - def clitic_set(self) -> Dict[str, Set[str]]: - """ - Generate mapping for dictionaries to clitic sets - - Returns - ------- - Dict[str, str] - Per dictionary clitic sets - """ - data = {} - for dictionary in self.current_dictionaries: - data[dictionary.name] = dictionary.clitic_set - return data - - def clitic_markers(self) -> Dict[str, str]: - """ - Generate mapping for dictionaries to clitic 
markers - - Returns - ------- - Dict[str, str] - Per dictionary clitic markers - """ - data = {} - for dictionary in self.current_dictionaries: - data[dictionary.name] = dictionary.clitic_markers - return data - - def compound_markers(self) -> Dict[str, str]: - """ - Generate mapping for dictionaries to compound markers - - Returns - ------- - Dict[str, str] - Per dictionary compound markers - """ - data = {} - for dictionary in self.current_dictionaries: - data[dictionary.name] = dictionary.compound_markers - return data - - def strip_diacritics(self) -> Dict[str, List[str]]: - """ - Generate mapping for dictionaries to diacritics to strip - - Returns - ------- - Dict[str, List[str]] - Per dictionary strip diacritics - """ - data = {} - for dictionary in self.current_dictionaries: - data[dictionary.name] = dictionary.strip_diacritics - return data - - def oov_codes(self) -> Dict[str, str]: - """ - Generate mapping for dictionaries to oov symbols - - Returns - ------- - Dict[str, str] - Per dictionary oov symbols - """ - data = {} - for dictionary in self.current_dictionaries: - data[dictionary.name] = dictionary.oov_code - return data - - def oov_ints(self) -> Dict[str, int]: - """ - Generate mapping for dictionaries to oov ints - - Returns - ------- - Dict[str, int] - Per dictionary oov ints - """ - data = {} - for dictionary in self.current_dictionaries: - data[dictionary.name] = dictionary.oov_int - return data - - def positions(self) -> Dict[str, List[str]]: - """ - Generate mapping for dictionaries to positions - - Returns - ------- - Dict[str, List[str]] - Per dictionary positions - """ - data = {} - for dictionary in self.current_dictionaries: - data[dictionary.name] = dictionary.positions - return data - - def silences(self) -> Dict[str, Set[str]]: - """ - Generate mapping for dictionaries to silence symbols - - Returns - ------- - Dict[str, Set[str]] - Per dictionary silence symbols - """ - data = {} - for dictionary in self.current_dictionaries: - 
data[dictionary.name] = dictionary.silences - return data - - def multilingual_ipa(self) -> Dict[str, bool]: - """ - Generate mapping for dictionaries to multilingual IPA flags - - Returns - ------- - Dict[str, bool] - Per dictionary multilingual IPA flags - """ - data = {} - for dictionary in self.current_dictionaries: - data[dictionary.name] = dictionary.multilingual_ipa - return data - - def generate_pronunciations_arguments( - self, aligner: Aligner - ) -> GeneratePronunciationsArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.pronunciations.generate_pronunciations_func` - - Parameters - ---------- - aligner: Aligner - Aligner - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.GeneratePronunciationsArguments` - Arguments for processing - """ - return GeneratePronunciationsArguments( - os.path.join( - aligner.working_log_directory, f"generate_pronunciations.{self.name}.log" - ), - self.current_dictionary_names, - self.construct_path_dictionary(aligner.data_directory, "text", "int.scp"), - self.word_boundary_int_files(), - self.construct_path_dictionary(aligner.working_directory, "ali", "ark"), - aligner.model_path, - self.construct_path_dictionary(aligner.working_directory, "prons", "scp"), - ) - - def alignment_improvement_arguments( - self, aligner: BaseTrainer - ) -> AlignmentImprovementArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.compute_alignment_improvement_func` - - Parameters - ---------- - aligner: :class:`~montreal_forced_aligner.trainers.BaseTrainer` - Aligner - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.AlignmentImprovementArguments` - Arguments for processing - """ - return AlignmentImprovementArguments( - os.path.join(aligner.working_log_directory, f"alignment_analysis.{self.name}.log"), - self.current_dictionary_names, - aligner.current_model_path, - 
self.construct_path_dictionary(aligner.data_directory, "text", "int.scp"), - self.word_boundary_int_files(), - self.construct_path_dictionary(aligner.working_directory, "ali", "ark"), - self.feature_config.frame_shift, - self.reversed_phone_mappings(), - self.positions(), - self.construct_path_dictionary( - aligner.working_directory, f"phone.{aligner.iteration}", "ctm" - ), - ) - - def ali_to_word_ctm_arguments(self, aligner: BaseAligner) -> AliToCtmArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.ali_to_ctm_func` - - Parameters - ---------- - aligner: Aligner - Aligner - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.AliToCtmArguments` - Arguments for processing - """ - return AliToCtmArguments( - os.path.join(aligner.working_log_directory, f"get_word_ctm.{self.name}.log"), - self.current_dictionary_names, - self.construct_path_dictionary(aligner.working_directory, "ali", "ark"), - self.construct_path_dictionary(aligner.data_directory, "text", "int.scp"), - self.word_boundary_int_files(), - round(self.feature_config.frame_shift / 1000, 4), - aligner.alignment_model_path, - self.construct_path_dictionary(aligner.working_directory, "word", "ctm"), - True, - ) - - def ali_to_phone_ctm_arguments(self, aligner: Aligner) -> AliToCtmArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.ali_to_ctm_func` - - Parameters - ---------- - aligner: Aligner - Aligner - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.AliToCtmArguments` - Arguments for processing - """ - return AliToCtmArguments( - os.path.join(aligner.working_log_directory, f"get_phone_ctm.{self.name}.log"), - self.current_dictionary_names, - self.construct_path_dictionary(aligner.working_directory, "ali", "ark"), - self.construct_path_dictionary(aligner.data_directory, "text", "int.scp"), - self.word_boundary_int_files(), - 
round(self.feature_config.frame_shift / 1000, 4), - aligner.alignment_model_path, - self.construct_path_dictionary(aligner.working_directory, "phone", "ctm"), - False, - ) - - def job_utts(self) -> Dict[str, Dict[str, Utterance]]: - """ - Generate utterances by dictionary name for the Job - - Returns - ------- - Dict[str, Dict[str, :class:`~montreal_forced_aligner.corpus.Utterance`]] - Mapping of dictionary name to Utterance mappings - """ - data = {} - speakers = self.subset_speakers - if not speakers: - speakers = self.speakers - for s in speakers: - if s.dictionary.name not in data: - data[s.dictionary.name] = {} - data[s.dictionary.name].update(s.utterances) - return data - - def job_files(self) -> Dict[str, File]: - """ - Generate files for the Job - - Returns - ------- - Dict[str, :class:`~montreal_forced_aligner.corpus.File`] - Mapping of file name to File objects - """ - data = {} - speakers = self.subset_speakers - if not speakers: - speakers = self.speakers - for s in speakers: - for f in s.files: - for sf in f.speaker_ordering: - if sf.name == s.name: - sf.dictionary_data = s.dictionary_data - data[f.name] = f - return data - - def cleanup_word_ctm_arguments(self, aligner: Aligner) -> CleanupWordCtmArguments: - """ - Generate Job arguments for :class:`~montreal_forced_aligner.multiprocessing.CleanupWordCtmProcessWorker` - - Parameters - ---------- - aligner: Aligner - Aligner - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.CleanupWordCtmArguments` - Arguments for processing - """ - return CleanupWordCtmArguments( - self.construct_path_dictionary(aligner.align_directory, "word", "ctm"), - self.current_dictionary_names, - self.job_utts(), - self.dictionary_data(), - ) - - def no_cleanup_word_ctm_arguments(self, aligner: Aligner) -> NoCleanupWordCtmArguments: - """ - Generate Job arguments for :class:`~montreal_forced_aligner.multiprocessing.NoCleanupWordCtmProcessWorker` - - Parameters - ---------- - aligner: Aligner - 
Aligner - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.NoCleanupWordCtmArguments` - Arguments for processing - """ - return NoCleanupWordCtmArguments( - self.construct_path_dictionary(aligner.align_directory, "word", "ctm"), - self.current_dictionary_names, - self.job_utts(), - self.dictionary_data(), - ) - - def phone_ctm_arguments(self, aligner: Aligner) -> PhoneCtmArguments: - """ - Generate Job arguments for :class:`~montreal_forced_aligner.multiprocessing.PhoneCtmProcessWorker` - - Parameters - ---------- - aligner: Aligner - Aligner - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.PhoneCtmArguments` - Arguments for processing - """ - return PhoneCtmArguments( - self.construct_path_dictionary(aligner.align_directory, "phone", "ctm"), - self.current_dictionary_names, - self.job_utts(), - self.reversed_phone_mappings(), - self.positions(), - ) - - def dictionary_data(self) -> Dict[str, DictionaryData]: - """ - Generate dictionary data for the job - - Returns - ------- - Dict[str, DictionaryData] - Mapping of dictionary name to dictionary data - """ - data = {} - for dictionary in self.current_dictionaries: - data[dictionary.name] = dictionary.data() - return data - - def combine_ctm_arguments(self, aligner: Aligner) -> CombineCtmArguments: - """ - Generate Job arguments for :class:`~montreal_forced_aligner.multiprocessing.CombineProcessWorker` - - Parameters - ---------- - aligner: Aligner - Aligner - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.CombineCtmArguments` - Arguments for processing - """ - return CombineCtmArguments( - self.current_dictionary_names, - self.job_files(), - self.dictionary_data(), - aligner.align_config.cleanup_textgrids, - ) - - def export_textgrid_arguments(self, aligner: Aligner) -> ExportTextGridArguments: - """ - Generate Job arguments for :class:`~montreal_forced_aligner.multiprocessing.ExportTextGridProcessWorker` - - Parameters - 
---------- - aligner: Aligner - Aligner - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.ExportTextGridArguments` - Arguments for processing - """ - return ExportTextGridArguments( - aligner.corpus.files, - aligner.feature_config.frame_shift, - aligner.textgrid_output, - aligner.backup_output_directory, - ) - - def tree_stats_arguments(self, aligner: BaseTrainer) -> TreeStatsArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.tree_stats_func` - - Parameters - ---------- - aligner: :class:`~montreal_forced_aligner.trainers.BaseTrainer` - Aligner - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.TreeStatsArguments` - Arguments for processing - """ - return TreeStatsArguments( - os.path.join(aligner.working_log_directory, f"acc_tree.{self.name}.log"), - self.current_dictionary_names, - aligner.dictionary.config.silence_csl, - aligner.previous_trainer.alignment_model_path, - self.construct_feature_proc_strings(aligner), - self.construct_path_dictionary(aligner.previous_trainer.align_directory, "ali", "ark"), - self.construct_path_dictionary(aligner.working_directory, "tree", "acc"), - ) - - def convert_alignment_arguments(self, aligner: BaseTrainer) -> ConvertAlignmentsArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.convert_alignments_func` - - Parameters - ---------- - aligner: :class:`~montreal_forced_aligner.trainers.BaseTrainer` - Aligner - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.ConvertAlignmentsArguments` - Arguments for processing - """ - return ConvertAlignmentsArguments( - os.path.join(aligner.working_log_directory, f"convert_alignments.{self.name}.log"), - self.current_dictionary_names, - aligner.current_model_path, - aligner.tree_path, - aligner.previous_trainer.alignment_model_path, - self.construct_path_dictionary( - 
aligner.previous_trainer.working_directory, "ali", "ark" - ), - self.construct_path_dictionary(aligner.working_directory, "ali", "ark"), - ) - - def calc_fmllr_arguments(self, aligner: Aligner) -> CalcFmllrArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.calc_fmllr_func` - - Parameters - ---------- - aligner: Aligner - Aligner - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.CalcFmllrArguments` - Arguments for processing - """ - return CalcFmllrArguments( - os.path.join(aligner.working_log_directory, f"calc_fmllr.{self.name}.log"), - self.current_dictionary_names, - self.construct_feature_proc_strings(aligner), - self.construct_path_dictionary(aligner.working_directory, "ali", "ark"), - aligner.alignment_model_path, - aligner.model_path, - self.construct_path_dictionary(aligner.data_directory, "spk2utt", "scp"), - self.construct_path_dictionary(aligner.working_directory, "trans", "ark"), - aligner.fmllr_options, - ) - - def acc_stats_two_feats_arguments(self, aligner: SatTrainer) -> AccStatsTwoFeatsArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.acc_stats_two_feats_func` - - Parameters - ---------- - aligner: :class:`~montreal_forced_aligner.trainers.SatTrainer` - Aligner - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.AccStatsTwoFeatsArguments` - Arguments for processing - """ - return AccStatsTwoFeatsArguments( - os.path.join(aligner.working_log_directory, f"acc_stats_two_feats.{self.name}.log"), - self.current_dictionary_names, - self.construct_path_dictionary(aligner.working_directory, "ali", "ark"), - self.construct_path_dictionary(aligner.working_directory, "two_feat_acc", "ark"), - aligner.current_model_path, - self.construct_feature_proc_strings(aligner), - self.construct_feature_proc_strings(aligner, speaker_independent=True), - ) - - def lda_acc_stats_arguments(self, aligner: 
LdaTrainer) -> LdaAccStatsArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.lda_acc_stats_func` - - Parameters - ---------- - aligner: :class:`~montreal_forced_aligner.trainers.LdaTrainer` - Aligner - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.LdaAccStatsArguments` - Arguments for processing - """ - return LdaAccStatsArguments( - os.path.join(aligner.working_log_directory, f"lda_acc_stats.{self.name}.log"), - self.current_dictionary_names, - self.construct_feature_proc_strings(aligner), - self.construct_path_dictionary( - aligner.previous_trainer.working_directory, "ali", "ark" - ), - aligner.previous_trainer.alignment_model_path, - aligner.lda_options, - self.construct_path_dictionary(aligner.working_directory, "lda", "acc"), - ) - - def calc_lda_mllt_arguments(self, aligner: LdaTrainer) -> CalcLdaMlltArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.calc_lda_mllt_func` - - Parameters - ---------- - aligner: :class:`~montreal_forced_aligner.trainers.LdaTrainer` - Aligner - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.CalcLdaMlltArguments` - Arguments for processing - """ - return CalcLdaMlltArguments( - os.path.join(aligner.working_log_directory, f"lda_mllt.{self.name}.log"), - self.current_dictionary_names, - self.construct_feature_proc_strings(aligner), - self.construct_path_dictionary(aligner.working_directory, "ali", "ark"), - aligner.current_model_path, - aligner.lda_options, - self.construct_path_dictionary(aligner.working_directory, "lda", "macc"), - ) - - def ivector_acc_stats_arguments( - self, trainer: IvectorExtractorTrainer - ) -> AccIvectorStatsArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.ivector.acc_ivector_stats_func` - - Parameters - ---------- - trainer: :class:`~montreal_forced_aligner.trainers.IvectorExtractorTrainer` - 
Aligner - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.AccIvectorStatsArguments` - Arguments for processing - """ - return AccIvectorStatsArguments( - os.path.join(trainer.working_log_directory, f"ivector_acc.{self.name}.log"), - self.current_dictionary_names, - self.construct_feature_proc_strings(trainer), - trainer.ivector_options, - trainer.current_ie_path, - self.construct_path_dictionary(trainer.working_directory, "post", "ark"), - self.construct_path_dictionary(trainer.working_directory, "ivector", "acc"), - ) - - def map_acc_stats_arguments(self, aligner: AdaptingAligner) -> MapAccStatsArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.map_acc_stats_func` - - Parameters - ---------- - aligner: :class:`~montreal_forced_aligner.aligner.AdaptingAligner` - Aligner - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.MapAccStatsArguments` - Arguments for processing - """ - return MapAccStatsArguments( - os.path.join(aligner.working_log_directory, f"map_acc_stats.{self.name}.log"), - self.current_dictionary_names, - self.construct_feature_proc_strings(aligner), - aligner.current_model_path, - self.construct_path_dictionary(aligner.previous_aligner.align_directory, "ali", "ark"), - self.construct_path_dictionary(aligner.working_directory, "map", "acc"), - ) - - def gmm_gselect_arguments(self, aligner: IvectorExtractorTrainer) -> GmmGselectArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.ivector.gmm_gselect_func` - - Parameters - ---------- - aligner: :class:`~montreal_forced_aligner.trainers.IvectorExtractorTrainer` - Aligner - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.GmmGselectArguments` - Arguments for processing - """ - return GmmGselectArguments( - os.path.join(aligner.working_log_directory, f"gmm_gselect.{self.name}.log"), - self.current_dictionary_names, - 
self.construct_feature_proc_strings(aligner), - aligner.ivector_options, - aligner.current_dubm_path, - self.construct_path_dictionary(aligner.working_directory, "gselect", "ark"), - ) - - def acc_global_stats_arguments( - self, aligner: IvectorExtractorTrainer - ) -> AccGlobalStatsArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.ivector.acc_global_stats_func` - - Parameters - ---------- - aligner: :class:`~montreal_forced_aligners.trainers.IvectorExtractorTrainer` - Aligner - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.AccGlobalStatsArguments` - Arguments for processing - """ - return AccGlobalStatsArguments( - os.path.join( - aligner.working_log_directory, - f"acc_global_stats.{aligner.iteration}.{self.name}.log", - ), - self.current_dictionary_names, - self.construct_feature_proc_strings(aligner), - aligner.ivector_options, - self.construct_path_dictionary(aligner.working_directory, "gselect", "ark"), - self.construct_path_dictionary( - aligner.working_directory, f"global.{aligner.iteration}", "acc" - ), - aligner.current_dubm_path, - ) - - def gauss_to_post_arguments(self, aligner: IvectorExtractorTrainer) -> GaussToPostArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.ivector.gauss_to_post_func` - - Parameters - ---------- - aligner: :class:`~montreal_forced_aligner.trainers.IvectorExtractorTrainer` - Aligner - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.GaussToPostArguments` - Arguments for processing - """ - return GaussToPostArguments( - os.path.join(aligner.working_log_directory, f"gauss_to_post.{self.name}.log"), - self.current_dictionary_names, - self.construct_feature_proc_strings(aligner), - aligner.ivector_options, - self.construct_path_dictionary(aligner.working_directory, "post", "ark"), - aligner.current_dubm_path, - ) - - def compile_train_graph_arguments(self, aligner: Aligner) -> 
CompileTrainGraphsArguments: - """ - Generate Job arguments for :func:`~montreal_forced_aligner.multiprocessing.alignment.compile_train_graphs_func` - - Parameters - ---------- - aligner: Aligner - Aligner - - Returns - ------- - :class:`~montreal_forced_aligner.multiprocessing.classes.CompileTrainGraphsArguments` - Arguments for processing - """ - dictionary_paths = aligner.dictionary.output_paths - disambig_paths = { - k: os.path.join(v, "phones", "disambiguation_symbols.int") - for k, v in dictionary_paths.items() - } - lexicon_fst_paths = {k: os.path.join(v, "L.fst") for k, v in dictionary_paths.items()} - model_path = aligner.model_path - if not os.path.exists(model_path): - model_path = aligner.alignment_model_path - return CompileTrainGraphsArguments( - os.path.join(aligner.working_log_directory, f"compile_train_graphs.{self.name}.log"), - self.current_dictionary_names, - os.path.join(aligner.working_directory, "tree"), - model_path, - self.construct_path_dictionary(aligner.data_directory, "text", "int.scp"), - disambig_paths, - lexicon_fst_paths, - self.construct_path_dictionary(aligner.working_directory, "fsts", "scp"), - ) - - def utt2fst_scp_data( - self, corpus: Corpus, num_frequent_words: int = 10 - ) -> Dict[str, List[Tuple[str, str]]]: - """ - Generate Kaldi style utt2fst scp data - - Parameters - ---------- - corpus: :class:`~montreal_forced_aligner.corpus.Corpus` - Corpus to generate data for - num_frequent_words: int - Number of frequent words to include in the unigram language model - - Returns - ------- - Dict[str, List[Tuple[str, str]]] - Utterance FSTs per dictionary name - """ - data = {} - most_frequent = {} - for dict_name, utterances in self.job_utts().items(): - data[dict_name] = [] - for u_name, utterance in utterances.items(): - new_text = [] - dictionary = utterance.speaker.dictionary - if dictionary.name not in most_frequent: - word_frequencies = corpus.get_word_frequency() - most_frequent[dictionary.name] = sorted( - 
word_frequencies.items(), key=lambda x: -x[1] - )[:num_frequent_words] - - for t in utterance.text: - lookup = utterance.speaker.dictionary.split_clitics(t) - if lookup is None: - continue - new_text.extend(x for x in lookup if x != "") - data[dict_name].append( - ( - u_name, - dictionary.create_utterance_fst(new_text, most_frequent[dictionary.name]), - ) - ) - return data - - def output_utt_fsts(self, corpus: Corpus, num_frequent_words: int = 10) -> None: - """ - Write utterance FSTs - - Parameters - ---------- - corpus: :class:`~montreal_forced_aligner.corpus.Corpus` - Corpus to generate FSTs for - num_frequent_words: int - Number of frequent words - """ - utt2fst = self.utt2fst_scp_data(corpus, num_frequent_words) - for dict_name, scp in utt2fst.items(): - utt2fst_scp_path = os.path.join( - corpus.split_directory, f"utt2fst.{dict_name}.{self.name}.scp" - ) - save_scp(scp, utt2fst_scp_path, multiline=True) - - def output_to_directory(self, split_directory: str) -> None: - """ - Output job information to a directory - - Parameters - ---------- - split_directory: str - Directory to output to - """ - wav = self.wav_scp_data() - for dict_name, scp in wav.items(): - wav_scp_path = os.path.join(split_directory, f"wav.{dict_name}.{self.name}.scp") - output_mapping(scp, wav_scp_path, skip_safe=True) - - spk2utt = self.spk2utt_scp_data() - for dict_name, scp in spk2utt.items(): - spk2utt_scp_path = os.path.join( - split_directory, f"spk2utt.{dict_name}.{self.name}.scp" - ) - output_mapping(scp, spk2utt_scp_path) - - feats = self.feat_scp_data() - for dict_name, scp in feats.items(): - feats_scp_path = os.path.join(split_directory, f"feats.{dict_name}.{self.name}.scp") - output_mapping(scp, feats_scp_path) - - cmvn = self.cmvn_scp_data() - for dict_name, scp in cmvn.items(): - cmvn_scp_path = os.path.join(split_directory, f"cmvn.{dict_name}.{self.name}.scp") - output_mapping(scp, cmvn_scp_path) - - utt2spk = self.utt2spk_scp_data() - for dict_name, scp in utt2spk.items(): 
- utt2spk_scp_path = os.path.join( - split_directory, f"utt2spk.{dict_name}.{self.name}.scp" - ) - output_mapping(scp, utt2spk_scp_path) - - segments = self.segments_scp_data() - for dict_name, scp in segments.items(): - segments_scp_path = os.path.join( - split_directory, f"segments.{dict_name}.{self.name}.scp" - ) - output_mapping(scp, segments_scp_path) - - text_scp = self.text_scp_data() - for dict_name, scp in text_scp.items(): - if not scp: - continue - text_scp_path = os.path.join(split_directory, f"text.{dict_name}.{self.name}.scp") - output_mapping(scp, text_scp_path) - - text_int = self.text_int_scp_data() - for dict_name, scp in text_int.items(): - if dict_name is None: - continue - if not scp: - continue - text_int_scp_path = os.path.join( - split_directory, f"text.{dict_name}.{self.name}.int.scp" - ) - output_mapping(scp, text_int_scp_path, skip_safe=True) diff --git a/montreal_forced_aligner/multiprocessing/corpus.py b/montreal_forced_aligner/multiprocessing/corpus.py deleted file mode 100644 index b3907088..00000000 --- a/montreal_forced_aligner/multiprocessing/corpus.py +++ /dev/null @@ -1,91 +0,0 @@ -""" -Corpus loading worker ---------------------- - - -""" -from __future__ import annotations - -import multiprocessing as mp -import sys -import traceback -from queue import Empty -from typing import TYPE_CHECKING, Dict, Union - -from ..exceptions import TextGridParseError, TextParseError - -if TYPE_CHECKING: - from ..corpus import OneToManyMappingType, OneToOneMappingType - from ..corpus.base import SoundFileInfoDict - - FileInfoDict = Dict[ - str, Union[str, SoundFileInfoDict, OneToOneMappingType, OneToManyMappingType] - ] - from .helper import Stopped - - -__all__ = ["CorpusProcessWorker"] - - -class CorpusProcessWorker(mp.Process): - """ - Multiprocessing corpus loading worker - - Attributes - ---------- - job_q: :class:`~multiprocessing.Queue` - Job queue for files to process - return_dict: Dict - Dictionary to catch errors - return_q: 
:class:`~multiprocessing.Queue` - Return queue for processed Files - stopped: :func:`~montreal_forced_aligner.multiprocessing.helper.Stopped` - Stop check for whether corpus loading should exit - finished_adding: :class:`~montreal_forced_aligner.multiprocessing.helper.Stopped` - Signal that the main thread has stopped adding new files to be processed - """ - - def __init__( - self, - job_q: mp.Queue, - return_dict: Dict, - return_q: mp.Queue, - stopped: Stopped, - finished_adding: Stopped, - ): - mp.Process.__init__(self) - self.job_q = job_q - self.return_dict = return_dict - self.return_q = return_q - self.stopped = stopped - self.finished_adding = finished_adding - - def run(self) -> None: - """ - Run the corpus loading job - """ - from ..corpus.classes import parse_file - - while True: - try: - arguments = self.job_q.get(timeout=1) - except Empty: - if self.finished_adding.stop_check(): - break - continue - self.job_q.task_done() - if self.stopped.stop_check(): - continue - try: - file = parse_file(*arguments, stop_check=self.stopped.stop_check) - self.return_q.put(file) - except TextParseError as e: - self.return_dict["decode_error_files"].append(e) - except TextGridParseError as e: - self.return_dict["textgrid_read_errors"][e.file_name] = e - except Exception: - self.stopped.stop() - self.return_dict["error"] = arguments, Exception( - traceback.format_exception(*sys.exc_info()) - ) - return diff --git a/montreal_forced_aligner/multiprocessing/features.py b/montreal_forced_aligner/multiprocessing/features.py deleted file mode 100644 index a1f35209..00000000 --- a/montreal_forced_aligner/multiprocessing/features.py +++ /dev/null @@ -1,225 +0,0 @@ -""" -Feature generation functions ----------------------------- - -""" -from __future__ import annotations - -import os -import shutil -import subprocess -from typing import TYPE_CHECKING, Dict, List, Union - -from ..helper import load_scp, make_safe -from ..multiprocessing import run_mp, run_non_mp -from ..utils 
import thirdparty_binary - -if TYPE_CHECKING: - SpeakerCharacterType = Union[str, int] - from ..abc import MetaDict - from ..corpus import Corpus - -__all__ = ["mfcc", "compute_vad", "calc_cmvn", "mfcc_func", "compute_vad_func"] - - -def mfcc_func( - log_path: str, - dictionaries: List[str], - feats_scp_paths: Dict[str, str], - lengths_paths: Dict[str, str], - segment_paths: Dict[str, str], - wav_paths: Dict[str, str], - mfcc_options: MetaDict, -) -> None: - """ - Multiprocessing function for generating MFCC features - - Parameters - ---------- - log_path: str - Path to save log output - dictionaries: List[str] - List of dictionary names - feats_scp_paths: Dict[str, str] - Dictionary of feature scp files per dictionary name - lengths_paths: Dict[str, str] - Dictionary of feature lengths files per dictionary name - segment_paths: Dict[str, str] - Dictionary of segment scp files per dictionary name - wav_paths: Dict[str, str] - Dictionary of sound file scp files per dictionary name - mfcc_options: :class:`~montreal_forced_aligner.abc.MetaDict` - Options for MFCC generation - """ - with open(log_path, "w") as log_file: - for dict_name in dictionaries: - mfcc_base_command = [thirdparty_binary("compute-mfcc-feats"), "--verbose=2"] - raw_ark_path = feats_scp_paths[dict_name].replace(".scp", ".ark") - for k, v in mfcc_options.items(): - mfcc_base_command.append(f"--{k.replace('_', '-')}={make_safe(v)}") - if os.path.exists(segment_paths[dict_name]): - mfcc_base_command += ["ark:-", "ark:-"] - seg_proc = subprocess.Popen( - [ - thirdparty_binary("extract-segments"), - f"scp,p:{wav_paths[dict_name]}", - segment_paths[dict_name], - "ark:-", - ], - stdout=subprocess.PIPE, - stderr=log_file, - env=os.environ, - ) - comp_proc = subprocess.Popen( - mfcc_base_command, - stdout=subprocess.PIPE, - stderr=log_file, - stdin=seg_proc.stdout, - env=os.environ, - ) - else: - mfcc_base_command += [f"scp,p:{wav_paths[dict_name]}", "ark:-"] - comp_proc = subprocess.Popen( - 
mfcc_base_command, stdout=subprocess.PIPE, stderr=log_file, env=os.environ - ) - copy_proc = subprocess.Popen( - [ - thirdparty_binary("copy-feats"), - "--compress=true", - "ark:-", - f"ark,scp:{raw_ark_path},{feats_scp_paths[dict_name]}", - ], - stdin=comp_proc.stdout, - stderr=log_file, - env=os.environ, - ) - copy_proc.communicate() - - utt_lengths_proc = subprocess.Popen( - [ - thirdparty_binary("feat-to-len"), - f"scp:{feats_scp_paths[dict_name]}", - f"ark,t:{lengths_paths[dict_name]}", - ], - stderr=log_file, - env=os.environ, - ) - utt_lengths_proc.communicate() - - -def mfcc(corpus: Corpus) -> None: - """ - Multiprocessing function that converts sound files into MFCCs - - See http://kaldi-asr.org/doc/feat.html and - http://kaldi-asr.org/doc/compute-mfcc-feats_8cc.html for more details on how - MFCCs are computed. - - Also see https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/make_mfcc.sh - for the bash script this function was based on. - - Parameters - ---------- - corpus : :class:`~montreal_forced_aligner.corpus.Corpus` - Corpus to generate MFCC features for - """ - log_directory = os.path.join(corpus.split_directory, "log") - os.makedirs(log_directory, exist_ok=True) - - jobs = [job.mfcc_arguments(corpus) for job in corpus.jobs] - if corpus.use_mp: - run_mp(mfcc_func, jobs, log_directory) - else: - run_non_mp(mfcc_func, jobs, log_directory) - - -def calc_cmvn(corpus: Corpus) -> None: - """ - Calculate CMVN statistics for speakers - - Parameters - ---------- - corpus : :class:`~montreal_forced_aligner.corpus.Corpus` - Corpus to run CMVN calculation - """ - spk2utt = os.path.join(corpus.output_directory, "spk2utt.scp") - feats = os.path.join(corpus.output_directory, "feats.scp") - cmvn_directory = os.path.join(corpus.features_directory, "cmvn") - os.makedirs(cmvn_directory, exist_ok=True) - cmvn_ark = os.path.join(cmvn_directory, "cmvn.ark") - cmvn_scp = os.path.join(cmvn_directory, "cmvn.scp") - log_path = os.path.join(cmvn_directory, 
"cmvn.log") - with open(log_path, "w") as logf: - subprocess.call( - [ - thirdparty_binary("compute-cmvn-stats"), - f"--spk2utt=ark:{spk2utt}", - f"scp:{feats}", - f"ark,scp:{cmvn_ark},{cmvn_scp}", - ], - stderr=logf, - env=os.environ, - ) - shutil.copy(cmvn_scp, os.path.join(corpus.output_directory, "cmvn.scp")) - for s, cmvn in load_scp(cmvn_scp).items(): - corpus.speakers[s].cmvn = cmvn - corpus.split() - - -def compute_vad_func( - log_path: str, - dictionaries: List[str], - feats_scp_paths: Dict[str, str], - vad_scp_paths: Dict[str, str], - vad_options: MetaDict, -) -> None: - """ - Multiprocessing function to compute voice activity detection - - Parameters - ---------- - log_path: str - Path to save log output - dictionaries: List[str] - List of dictionary names - feats_scp_paths: Dict[str, str] - PronunciationDictionary of feature scp files per dictionary name - vad_scp_paths: Dict[str, str] - PronunciationDictionary of vad scp files per dictionary name - vad_options: :class:`~montreal_forced_aligner.abc.MetaDict` - Options for VAD - """ - with open(log_path, "w") as log_file: - for dict_name in dictionaries: - feats_scp_path = feats_scp_paths[dict_name] - vad_scp_path = vad_scp_paths[dict_name] - vad_proc = subprocess.Popen( - [ - thirdparty_binary("compute-vad"), - f"--vad-energy-mean-scale={vad_options['energy_mean_scale']}", - f"--vad-energy-threshold={vad_options['energy_threshold']}", - "scp:" + feats_scp_path, - f"ark,t:{vad_scp_path}", - ], - stderr=log_file, - env=os.environ, - ) - vad_proc.communicate() - - -def compute_vad(corpus: Corpus) -> None: - """ - Compute VAD for a corpus - - Parameters - ---------- - corpus : :class:`~montreal_forced_aligner.corpus.Corpus` - Corpus to compute VAD - """ - log_directory = os.path.join(corpus.split_directory, "log") - os.makedirs(log_directory, exist_ok=True) - jobs = [x.vad_arguments(corpus) for x in corpus.jobs] - if corpus.use_mp: - run_mp(compute_vad_func, jobs, log_directory) - else: - 
run_non_mp(compute_vad_func, jobs, log_directory) diff --git a/montreal_forced_aligner/multiprocessing/helper.py b/montreal_forced_aligner/multiprocessing/helper.py deleted file mode 100644 index bfb32979..00000000 --- a/montreal_forced_aligner/multiprocessing/helper.py +++ /dev/null @@ -1,228 +0,0 @@ -""" -Multiprocessing helpers ------------------------ - -""" -from __future__ import annotations - -import multiprocessing as mp -import os -import sys -import traceback -from queue import Empty -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -from ..utils import parse_logs - -__all__ = ["Counter", "Stopped", "ProcessWorker", "run_mp", "run_non_mp"] - - -class Counter(object): - """ - Multiprocessing counter object for keeping track of progress - - Attributes - ---------- - val: :func:`~multiprocessing.Value` - Integer to increment - lock: :class:`~multiprocessing.Lock` - Lock for process safety - """ - - def __init__(self, init_val: int = 0): - self.val = mp.Value("i", init_val) - self.lock = mp.Lock() - - def increment(self) -> None: - """Increment the counter""" - with self.lock: - self.val.value += 1 - - def value(self) -> int: - """Get the current value of the counter""" - with self.lock: - return self.val.value - - -class Stopped(object): - """ - Multiprocessing class for detecting whether processes should stop processing and exit ASAP - - Attributes - ---------- - val: :func:`~multiprocessing.Value` - 0 if not stopped, 1 if stopped - lock: :class:`~multiprocessing.Lock` - Lock for process safety - _source: multiprocessing.Value - 1 if it was a Ctrl+C event that stopped it, 0 otherwise - """ - - def __init__(self, initval: Union[bool, int] = False): - self.val = mp.Value("i", initval) - self.lock = mp.Lock() - self._source = mp.Value("i", 0) - - def stop(self) -> None: - """Signal that work should stop asap""" - with self.lock: - self.val.value = True - - def stop_check(self) -> int: - """Check whether a process should stop""" - with 
self.lock: - return self.val.value - - def set_sigint_source(self) -> None: - """Set the source as a ctrl+c""" - with self.lock: - self._source.value = True - - def source(self) -> int: - """Get the source value""" - with self.lock: - return self._source.value - - -class ProcessWorker(mp.Process): - """ - Multiprocessing function work - - Parameters - ---------- - job_name: int - Integer number of job - job_q: :class:`~multiprocessing.Queue` - Job queue to pull arguments from - function: Callable - Multiprocessing function to call on arguments from job_q - return_dict: Dict - Dictionary for collecting errors - stopped: :class:`~montreal_forced_aligner.multiprocessing.helper.Stopped` - Stop check - return_info: Dict[int, Any], optional - Optional dictionary to fill if the function should return information to main thread - """ - - def __init__( - self, - job_name: int, - job_q: mp.Queue, - function: Callable, - return_dict: Dict, - stopped: Stopped, - return_info: Optional[Dict[int, Any]] = None, - ): - mp.Process.__init__(self) - self.job_name = job_name - self.function = function - self.job_q = job_q - self.return_dict = return_dict - self.return_info = return_info - self.stopped = stopped - - def run(self) -> None: - """ - Run through the arguments in the queue apply the function to them - """ - try: - arguments = self.job_q.get(timeout=1) - except Empty: - return - self.job_q.task_done() - try: - result = self.function(*arguments) - if self.return_info is not None: - self.return_info[self.job_name] = result - except Exception: - self.stopped.stop() - self.return_dict["error"] = arguments, Exception( - traceback.format_exception(*sys.exc_info()) - ) - - -def run_non_mp( - function: Callable, - argument_list: List[Tuple[Any, ...]], - log_directory: str, - return_info: bool = False, -) -> Optional[Dict[Any, Any]]: - """ - Similar to run_mp, but no additional processes are used and the jobs are evaluated in sequential order - - Parameters - ---------- - function: 
Callable - Multiprocessing function to evaluate - argument_list: List - List of arguments to process - log_directory: str - Directory that all log information from the processes goes to - return_info: Dict, optional - If the function returns information, supply the return dict to populate - - Returns - ------- - Dict, optional - If the function returns information, returns the dictionary it was supplied with - """ - if return_info: - info = {} - for i, args in enumerate(argument_list): - info[i] = function(*args) - parse_logs(log_directory) - return info - - for args in argument_list: - function(*args) - parse_logs(log_directory) - - -def run_mp( - function: Callable, - argument_list: List[Tuple[Any, ...]], - log_directory: str, - return_info: bool = False, -) -> Optional[Dict[int, Any]]: - """ - Apply a function for each job in parallel - - Parameters - ---------- - function: Callable - Multiprocessing function to apply - argument_list: List - List of arguments for each job - log_directory: str - Directory that all log information from the processes goes to - return_info: Dict, optional - If the function returns information, supply the return dict to populate - """ - from ..config import BLAS_THREADS - - os.environ["OPENBLAS_NUM_THREADS"] = f"{BLAS_THREADS}" - os.environ["MKL_NUM_THREADS"] = f"{BLAS_THREADS}" - stopped = Stopped() - manager = mp.Manager() - job_queue = manager.Queue() - return_dict = manager.dict() - info = None - if return_info: - info = manager.dict() - for a in argument_list: - job_queue.put(a) - procs = [] - for i in range(len(argument_list)): - p = ProcessWorker(i, job_queue, function, return_dict, stopped, info) - procs.append(p) - p.start() - - for p in procs: - p.join() - if "error" in return_dict: - _, exc = return_dict["error"] - raise exc - - parse_logs(log_directory) - if return_info: - return info diff --git a/montreal_forced_aligner/multiprocessing/ivector.py b/montreal_forced_aligner/multiprocessing/ivector.py deleted file mode 
100644 index 7051fde6..00000000 --- a/montreal_forced_aligner/multiprocessing/ivector.py +++ /dev/null @@ -1,790 +0,0 @@ -""" -Ivector extractor functions ---------------------------- - - -""" -from __future__ import annotations - -import os -import subprocess -from typing import TYPE_CHECKING, Dict, List, Union - -from ..abc import MetaDict -from ..helper import load_scp -from ..utils import thirdparty_binary -from .helper import run_mp, run_non_mp - -if TYPE_CHECKING: - from ..abc import IvectorExtractor - from ..corpus.classes import File, Speaker, Utterance # noqa - from ..segmenter import SegmentationType, Segmenter - from ..trainers.ivector_extractor import IvectorExtractorTrainer - - -__all__ = [ - "gmm_gselect", - "acc_global_stats", - "gauss_to_post", - "acc_ivector_stats", - "extract_ivectors", - "get_initial_segmentation", - "merge_segments", - "segment_vad", - "segment_vad_func", - "gmm_gselect_func", - "gauss_to_post_func", - "acc_global_stats_func", - "acc_ivector_stats_func", - "extract_ivectors_func", -] - - -def gmm_gselect_func( - log_path: str, - dictionaries: List[str], - feature_strings: Dict[str, str], - ivector_options: MetaDict, - dubm_path: str, - gselect_paths: Dict[str, str], -) -> None: - """ - Multiprocessing function for running gmm-gselect - - Parameters - ---------- - log_path: str - Path to save log output - dictionaries: List[str] - List of dictionary names - feature_strings: Dict[str, str] - Dictionary of feature strings per dictionary name - ivector_options: :class:`~montreal_forced_aligner.abc.MetaDict` - Options for ivector extractor training - dubm_path: str - Path to the DUBM file - gselect_paths: Dict[str, str] - Dictionary of gselect archives per dictionary name - """ - with open(log_path, "w", encoding="utf8") as log_file: - for dict_name in dictionaries: - feature_string = feature_strings[dict_name] - gselect_path = gselect_paths[dict_name] - subsample_feats_proc = subprocess.Popen( - [ - 
thirdparty_binary("subsample-feats"), - f"--n={ivector_options['subsample']}", - feature_string, - "ark:-", - ], - stdout=subprocess.PIPE, - stderr=log_file, - env=os.environ, - ) - - gselect_proc = subprocess.Popen( - [ - thirdparty_binary("gmm-gselect"), - f"--n={ivector_options['num_gselect']}", - dubm_path, - "ark:-", - f"ark:{gselect_path}", - ], - stdin=subsample_feats_proc.stdout, - stderr=log_file, - env=os.environ, - ) - gselect_proc.communicate() - - -def gmm_gselect(trainer: IvectorExtractorTrainer) -> None: - """ - Multiprocessing function that stores Gaussian selection indices on disk - - See: - - - http://kaldi-asr.org/doc/gmm-gselect_8cc.html - - for more details - on the Kaldi binary this runs. - - Also see https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/train_diag_ubm.sh - for the original bash script that this function was based on. - - Parameters - ---------- - trainer : :class:`~montreal_forced_aligner.trainers.IvectorExtractorTrainer` - Ivector Extractor Trainer - """ - jobs = [x.gmm_gselect_arguments(trainer) for x in trainer.corpus.jobs] - if trainer.use_mp: - run_mp(gmm_gselect_func, jobs, trainer.working_log_directory) - else: - run_non_mp(gmm_gselect_func, jobs, trainer.working_log_directory) - - -def acc_global_stats_func( - log_path: str, - dictionaries: List[str], - feature_strings: Dict[str, str], - ivector_options: MetaDict, - gselect_paths: Dict[str, str], - acc_paths: Dict[str, str], - dubm_path: str, -) -> None: - """ - Multiprocessing function for accumulating global model stats - - Parameters - ---------- - log_path: str - Path to save log output - dictionaries: List[str] - List of dictionary names - feature_strings: Dict[str, str] - Dictionary of feature strings per dictionary name - ivector_options: :class:`~montreal_forced_aligner.abc.MetaDict` - Options for ivector extractor training - gselect_paths: Dict[str, str] - Dictionary of gselect archives per dictionary name - acc_paths: Dict[str, str] - Dictionary of 
accumulated stats files per dictionary name - dubm_path: str - Path to the DUBM file - """ - with open(log_path, "w", encoding="utf8") as log_file: - for dict_name in dictionaries: - feature_string = feature_strings[dict_name] - gselect_path = gselect_paths[dict_name] - acc_path = acc_paths[dict_name] - subsample_feats_proc = subprocess.Popen( - [ - thirdparty_binary("subsample-feats"), - f"--n={ivector_options['subsample']}", - feature_string, - "ark:-", - ], - stdout=subprocess.PIPE, - stderr=log_file, - env=os.environ, - ) - gmm_global_acc_proc = subprocess.Popen( - [ - thirdparty_binary("gmm-global-acc-stats"), - f"--gselect=ark:{gselect_path}", - dubm_path, - "ark:-", - acc_path, - ], - stderr=log_file, - stdin=subsample_feats_proc.stdout, - env=os.environ, - ) - gmm_global_acc_proc.communicate() - - -def acc_global_stats(trainer: IvectorExtractorTrainer) -> None: - """ - Multiprocessing function that accumulates global GMM stats - - See: - - - http://kaldi-asr.org/doc/gmm-global-acc-stats_8cc.html - - for more details - on the Kaldi binary this runs. - - Also see https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/train_diag_ubm.sh - for the original bash script that this function was based on. 
- - Parameters - ---------- - trainer : :class:`~montreal_forced_aligner.trainers.IvectorExtractorTrainer` - Ivector Extractor Trainer - """ - jobs = [x.acc_global_stats_arguments(trainer) for x in trainer.corpus.jobs] - if trainer.use_mp: - run_mp(acc_global_stats_func, jobs, trainer.working_log_directory) - else: - run_non_mp(acc_global_stats_func, jobs, trainer.working_log_directory) - - # Don't remove low-count Gaussians till the last tier, - # or gselect info won't be valid anymore - if trainer.iteration < trainer.ubm_num_iterations: - opt = "--remove-low-count-gaussians=false" - else: - opt = f"--remove-low-count-gaussians={trainer.ubm_remove_low_count_gaussians}" - log_path = os.path.join(trainer.working_log_directory, f"update.{trainer.iteration}.log") - with open(log_path, "w") as log_file: - acc_files = [] - for j in jobs: - acc_files.extend(j.acc_paths.values()) - sum_proc = subprocess.Popen( - [thirdparty_binary("gmm-global-sum-accs"), "-"] + acc_files, - stderr=log_file, - stdout=subprocess.PIPE, - env=os.environ, - ) - gmm_global_est_proc = subprocess.Popen( - [ - thirdparty_binary("gmm-global-est"), - opt, - f"--min-gaussian-weight={trainer.ubm_min_gaussian_weight}", - trainer.current_dubm_path, - "-", - trainer.next_dubm_path, - ], - stderr=log_file, - stdin=sum_proc.stdout, - env=os.environ, - ) - gmm_global_est_proc.communicate() - # Clean up - if not trainer.debug: - for p in acc_files: - os.remove(p) - - -def gauss_to_post_func( - log_path: str, - dictionaries: List[str], - feature_strings: Dict[str, str], - ivector_options: MetaDict, - post_paths: Dict[str, str], - dubm_path: str, -): - """ - Multiprocessing function to get posteriors during UBM training - - Parameters - ---------- - log_path: str - Path to save log output - dictionaries: List[str] - List of dictionary names - feature_strings: Dict[str, str] - Dictionary of feature strings per dictionary name - ivector_options: :class:`~montreal_forced_aligner.abc.MetaDict` - Options for 
ivector extractor training - post_paths: Dict[str, str] - Dictionary of posterior archives per dictionary name - dubm_path: str - Path to the DUBM file - """ - modified_posterior_scale = ivector_options["posterior_scale"] * ivector_options["subsample"] - with open(log_path, "w", encoding="utf8") as log_file: - for dict_name in dictionaries: - feature_string = feature_strings[dict_name] - post_path = post_paths[dict_name] - subsample_feats_proc = subprocess.Popen( - [ - thirdparty_binary("subsample-feats"), - f"--n={ivector_options['subsample']}", - feature_string, - "ark:-", - ], - stdout=subprocess.PIPE, - stderr=log_file, - env=os.environ, - ) - gmm_global_get_post_proc = subprocess.Popen( - [ - thirdparty_binary("gmm-global-get-post"), - f"--n={ivector_options['num_gselect']}", - f"--min-post={ivector_options['min_post']}", - dubm_path, - "ark:-", - "ark:-", - ], - stdout=subprocess.PIPE, - stdin=subsample_feats_proc.stdout, - stderr=log_file, - env=os.environ, - ) - scale_post_proc = subprocess.Popen( - [ - thirdparty_binary("scale-post"), - "ark:-", - str(modified_posterior_scale), - f"ark:{post_path}", - ], - stdin=gmm_global_get_post_proc.stdout, - stderr=log_file, - env=os.environ, - ) - scale_post_proc.communicate() - - -def gauss_to_post(trainer: IvectorExtractorTrainer) -> None: - """ - Multiprocessing function that does Gaussian selection and posterior extraction - - See: - - - http://kaldi-asr.org/doc/gmm-global-get-post_8cc.html - - http://kaldi-asr.org/doc/scale-post_8cc.html - - for more details - on the Kaldi binary this runs. - - Also see https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/online/nnet2/train_ivector_extractor.sh - for the original bash script that this function was based on. 
- - Parameters - ---------- - trainer: :class:`~montreal_forced_aligner.trainers.IvectorExtractorTrainer` - Ivector Extractor Trainer - """ - jobs = [x.gauss_to_post_arguments(trainer) for x in trainer.corpus.jobs] - if trainer.use_mp: - run_mp(gauss_to_post_func, jobs, trainer.working_log_directory) - else: - run_non_mp(gauss_to_post_func, jobs, trainer.working_log_directory) - - -def acc_ivector_stats_func( - log_path: str, - dictionaries: List[str], - feature_strings: Dict[str, str], - ivector_options: MetaDict, - ie_path: str, - post_paths: Dict[str, str], - acc_init_paths: Dict[str, str], -) -> None: - """ - Multiprocessing function that accumulates stats for ivector training - - Parameters - ---------- - log_path: str - Path to save log output - dictionaries: List[str] - List of dictionary names - feature_strings: Dict[str, str] - PronunciationDictionary of feature strings per dictionary name - ivector_options: :class:`~montreal_forced_aligner.abc.MetaDict` - Options for ivector extractor training - ie_path: str - Path to the ivector extractor file - post_paths: Dict[str, str] - PronunciationDictionary of posterior archives per dictionary name - acc_init_paths: Dict[str, str] - PronunciationDictionary of accumulated stats files per dictionary name - """ - with open(log_path, "w", encoding="utf8") as log_file: - for dict_name in dictionaries: - feature_string = feature_strings[dict_name] - post_path = post_paths[dict_name] - acc_init_path = acc_init_paths[dict_name] - subsample_feats_proc = subprocess.Popen( - [ - thirdparty_binary("subsample-feats"), - f"--n={ivector_options['subsample']}", - feature_string, - "ark:-", - ], - stdout=subprocess.PIPE, - stderr=log_file, - env=os.environ, - ) - acc_stats_proc = subprocess.Popen( - [ - thirdparty_binary("ivector-extractor-acc-stats"), - "--num-threads=1", - ie_path, - "ark:-", - f"ark:{post_path}", - acc_init_path, - ], - stdin=subsample_feats_proc.stdout, - stderr=log_file, - env=os.environ, - ) - 
acc_stats_proc.communicate() - - -def acc_ivector_stats(trainer: IvectorExtractorTrainer) -> None: - """ - Multiprocessing function that calculates job_name-vector extractor stats - - See: - - - http://kaldi-asr.org/doc/ivector-extractor-acc-stats_8cc.html - - http://kaldi-asr.org/doc/ivector-extractor-sum-accs_8cc.html - - for more details - on the Kaldi binary this runs. - - Also see https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/online/nnet2/train_ivector_extractor.sh - for the original bash script that this function was based on. - - Parameters - ---------- - trainer: :class:`~montreal_forced_aligner.trainers.IvectorExtractorTrainer` - Ivector Extractor Trainer - """ - - jobs = [x.ivector_acc_stats_arguments(trainer) for x in trainer.corpus.jobs] - if trainer.use_mp: - run_mp(acc_ivector_stats_func, jobs, trainer.working_log_directory) - else: - run_non_mp(acc_ivector_stats_func, jobs, trainer.working_log_directory) - - log_path = os.path.join(trainer.working_log_directory, f"sum_acc.{trainer.iteration}.log") - acc_path = os.path.join(trainer.working_directory, f"acc.{trainer.iteration}") - with open(log_path, "w", encoding="utf8") as log_file: - accinits = [] - for j in jobs: - accinits.extend(j.acc_init_paths.values()) - sum_accs_proc = subprocess.Popen( - [thirdparty_binary("ivector-extractor-sum-accs"), "--parallel=true"] - + accinits - + [acc_path], - stderr=log_file, - env=os.environ, - ) - - sum_accs_proc.communicate() - # clean up - for p in accinits: - os.remove(p) - # Est extractor - log_path = os.path.join(trainer.working_log_directory, f"update.{trainer.iteration}.log") - with open(log_path, "w") as log_file: - extractor_est_proc = subprocess.Popen( - [ - thirdparty_binary("ivector-extractor-est"), - f"--num-threads={trainer.corpus.num_jobs}", - f"--gaussian-min-count={trainer.gaussian_min_count}", - trainer.current_ie_path, - os.path.join(trainer.working_directory, f"acc.{trainer.iteration}"), - trainer.next_ie_path, - ], - 
stderr=log_file, - env=os.environ, - ) - extractor_est_proc.communicate() - - -def extract_ivectors_func( - log_path: str, - dictionaries: List[str], - feature_strings: Dict[str, str], - ivector_options: MetaDict, - ali_paths: Dict[str, str], - ie_path: str, - ivector_paths: Dict[str, str], - weight_paths: Dict[str, str], - model_path: str, - dubm_path: str, -) -> None: - """ - Multiprocessing function for extracting ivectors - - Parameters - ---------- - log_path: str - Path to save log output - dictionaries: List[str] - List of dictionary names - feature_strings: Dict[str, str] - Dictionary of feature strings per dictionary name - ivector_options: :class:`~montreal_forced_aligner.abc.MetaDict` - Options for ivector extraction - ali_paths: Dict[str, str] - Dictionary of alignment archives per dictionary name - ie_path: str - Path to the ivector extractor file - ivector_paths: Dict[str, str] - Dictionary of ivector archives per dictionary name - weight_paths: Dict[str, str] - Dictionary of weighted archives per dictionary name - model_path: str - Path to the acoustic model file - dubm_path: str - Path to the DUBM file - """ - with open(log_path, "w", encoding="utf8") as log_file: - for dict_name in dictionaries: - ali_path = ali_paths[dict_name] - weight_path = weight_paths[dict_name] - ivectors_path = ivector_paths[dict_name] - feature_string = feature_strings[dict_name] - use_align = os.path.exists(ali_path) - if use_align: - ali_to_post_proc = subprocess.Popen( - [thirdparty_binary("ali-to-post"), f"ark:{ali_path}", "ark:-"], - stderr=log_file, - stdout=subprocess.PIPE, - env=os.environ, - ) - weight_silence_proc = subprocess.Popen( - [ - thirdparty_binary("weight-silence-post"), - str(ivector_options["silence_weight"]), - ivector_options["sil_phones"], - model_path, - "ark:-", - "ark:-", - ], - stderr=log_file, - stdin=ali_to_post_proc.stdout, - stdout=subprocess.PIPE, - env=os.environ, - ) - post_to_weight_proc = subprocess.Popen( - 
[thirdparty_binary("post-to-weights"), "ark:-", f"ark:{weight_path}"], - stderr=log_file, - stdin=weight_silence_proc.stdout, - env=os.environ, - ) - post_to_weight_proc.communicate() - - gmm_global_get_post_proc = subprocess.Popen( - [ - thirdparty_binary("gmm-global-get-post"), - f"--n={ivector_options['num_gselect']}", - f"--min-post={ivector_options['min_post']}", - dubm_path, - feature_string, - "ark:-", - ], - stdout=subprocess.PIPE, - stderr=log_file, - env=os.environ, - ) - if use_align: - weight_proc = subprocess.Popen( - [ - thirdparty_binary("weight-post"), - "ark:-", - f"ark,s,cs:{weight_path}", - "ark:-", - ], - stdin=gmm_global_get_post_proc.stdout, - stdout=subprocess.PIPE, - stderr=log_file, - env=os.environ, - ) - extract_in = weight_proc.stdout - else: - extract_in = gmm_global_get_post_proc.stdout - extract_proc = subprocess.Popen( - [ - thirdparty_binary("ivector-extract"), - f"--acoustic-weight={ivector_options['posterior_scale']}", - "--compute-objf-change=true", - f"--max-count={ivector_options['max_count']}", - ie_path, - feature_string, - "ark,s,cs:-", - f"ark,t:{ivectors_path}", - ], - stderr=log_file, - stdin=extract_in, - env=os.environ, - ) - extract_proc.communicate() - - -def extract_ivectors(ivector_extractor: IvectorExtractor) -> None: - """ - Multiprocessing function that extracts job_name-vectors. - - See: - - - http://kaldi-asr.org/doc/ivector-extract-online2_8cc.html - - http://kaldi-asr.org/doc/copy-feats_8cc.html - - for more details - on the Kaldi binary this runs. - - Also see https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/online/nnet2/extract_ivectors_online.sh - for the original bash script that this function was based on. 
- - Parameters - ---------- - ivector_extractor: IvectorExtractor - Ivector extractor - """ - - log_dir = ivector_extractor.log_directory - os.makedirs(log_dir, exist_ok=True) - - jobs = [x.extract_ivector_arguments(ivector_extractor) for x in ivector_extractor.corpus.jobs] - if ivector_extractor.use_mp: - run_mp(extract_ivectors_func, jobs, log_dir) - else: - run_non_mp(extract_ivectors_func, jobs, log_dir) - - -def get_initial_segmentation(frames: List[Union[int, str]], frame_shift: int) -> SegmentationType: - """ - Compute initial segmentation over voice activity - - Parameters - ---------- - frames: List[Union[int, str]] - List of frames with VAD output - frame_shift: int - Frame shift of features in ms - - Returns - ------- - SegmentationType - Initial segmentation - """ - segs = [] - cur_seg = None - silent_frames = 0 - non_silent_frames = 0 - for i, f in enumerate(frames): - if int(f) > 0: - non_silent_frames += 1 - if cur_seg is None: - cur_seg = {"begin": i * frame_shift} - else: - silent_frames += 1 - if cur_seg is not None: - cur_seg["end"] = (i - 1) * frame_shift - segs.append(cur_seg) - cur_seg = None - if cur_seg is not None: - cur_seg["end"] = len(frames) * frame_shift - segs.append(cur_seg) - return segs - - -def merge_segments( - segments: SegmentationType, - min_pause_duration: float, - max_segment_length: float, - snap_boundary_threshold: float, -) -> SegmentationType: - """ - Merge segments together - - Parameters - ---------- - segments: SegmentationType - Initial segments - min_pause_duration: float - Minimum amount of silence time to mark an utterance boundary - max_segment_length: float - Maximum length of segments before they're broken up - snap_boundary_threshold: - Boundary threshold to snap boundaries together - - Returns - ------- - SegmentationType - Merged segments - """ - merged_segs = [] - for s in segments: - if ( - not merged_segs - or s["begin"] > merged_segs[-1]["end"] + min_pause_duration - or s["end"] - 
merged_segs[-1]["begin"] > max_segment_length - ): - if s["end"] - s["begin"] > min_pause_duration: - if merged_segs and snap_boundary_threshold: - boundary_gap = s["begin"] - merged_segs[-1]["end"] - if boundary_gap < snap_boundary_threshold: - half_boundary = boundary_gap / 2 - else: - half_boundary = snap_boundary_threshold / 2 - merged_segs[-1]["end"] += half_boundary - s["begin"] -= half_boundary - - merged_segs.append(s) - else: - merged_segs[-1]["end"] = s["end"] - return merged_segs - - -def segment_vad_func( - dictionaries: List[str], - vad_paths: Dict[str, str], - segmentation_options: MetaDict, -) -> Dict[str, Utterance]: - """ - Multiprocessing function to generate segments from VAD output - - Parameters - ---------- - log_path: str - Path to save log output - dictionaries: List[str] - List of dictionary names - vad_paths: Dict[str, str] - Dictionary of VAD archives per dictionary name - segmentation_options: :class:`~montreal_forced_aligner.abc.MetaDict` - Options for segmentation - """ - - utterances = {} - from ..corpus.classes import File, Speaker, Utterance # noqa - - speaker = Speaker("speech") - for dict_name in dictionaries: - vad_path = vad_paths[dict_name] - - vad = load_scp(vad_path, data_type=int) - for recording, frames in vad.items(): - file = File(recording) - initial_segments = get_initial_segmentation( - frames, segmentation_options["frame_shift"] - ) - merged = merge_segments( - initial_segments, - segmentation_options["min_pause_duration"], - segmentation_options["max_segment_length"], - segmentation_options["snap_boundary_threshold"], - ) - for seg in merged: - utterances[recording] = Utterance( - speaker, file, begin=seg["begin"], end=seg["end"], text="speech" - ) - return utterances - - -def segment_vad(segmenter: Segmenter) -> None: - """ - Run segmentation based off of VAD - - Parameters - ---------- - segmenter: :class:`~montreal_forced_aligner.segmenter.Segmenter` - Segmenter - """ - - from ..corpus.classes import Speaker # 
noqa - - jobs = [x.segments_vad_arguments(segmenter) for x in segmenter.corpus.jobs] - if segmenter.segmentation_config.use_mp: - segment_info = run_mp( - segment_vad_func, jobs, segmenter.corpus.features_log_directory, True - ) - else: - segment_info = run_non_mp( - segment_vad_func, jobs, segmenter.corpus.features_log_directory, True - ) - for j in segmenter.corpus.jobs: - for old_utt, utterance in segment_info[j.name].items(): - old_utt = segmenter.corpus.utterances[old_utt] - file = old_utt.file - if segmenter.corpus.no_speakers: - if utterance.speaker_name not in segmenter.corpus.speakers: - segmenter.corpus.speakers[utterance.speaker_name] = Speaker( - utterance.speaker_name - ) - speaker = segmenter.corpus.speakers[utterance.speaker_name] - else: - speaker = old_utt.speaker - utterance.file = file - utterance.set_speaker(speaker) - segmenter.corpus.add_utterance(utterance) - utterance_ids = [x.name for x in segmenter.corpus.utterances.values() if x.begin is None] - for u in utterance_ids: - segmenter.corpus.delete_utterance(u) diff --git a/montreal_forced_aligner/multiprocessing/pronunciations.py b/montreal_forced_aligner/multiprocessing/pronunciations.py deleted file mode 100644 index f84fb334..00000000 --- a/montreal_forced_aligner/multiprocessing/pronunciations.py +++ /dev/null @@ -1,138 +0,0 @@ -""" -Pronunciation probability functions ------------------------------------ - -""" -from __future__ import annotations - -import os -import subprocess -from collections import Counter, defaultdict -from typing import Dict, List, Tuple - -from ..abc import Aligner -from ..utils import thirdparty_binary -from .helper import run_mp, run_non_mp - -__all__ = ["generate_pronunciations", "generate_pronunciations_func"] - - -def generate_pronunciations_func( - log_path: str, - dictionaries: List[str], - text_int_paths: Dict[str, str], - word_boundary_paths: Dict[str, str], - ali_paths: Dict[str, str], - model_path: str, - pron_paths: Dict[str, str], -): - """ - 
Multiprocessing function for generating pronunciations - - Parameters - ---------- - log_path: str - Path to save log output - dictionaries: List[str] - List of dictionary names - text_int_paths: Dict[str, str] - Dictionary of text int files per dictionary name - word_boundary_paths: Dict[str, str] - Dictionary of word boundary files per dictionary name - ali_paths: Dict[str, str] - Dictionary of alignment archives per dictionary name - model_path: str - Path to acoustic model file - pron_paths: Dict[str, str] - Dictionary of pronunciation archives per dictionary name - """ - with open(log_path, "w", encoding="utf8") as log_file: - for dict_name in dictionaries: - text_int_path = text_int_paths[dict_name] - word_boundary_path = word_boundary_paths[dict_name] - ali_path = ali_paths[dict_name] - pron_path = pron_paths[dict_name] - - lin_proc = subprocess.Popen( - [ - thirdparty_binary("linear-to-nbest"), - f"ark:{ali_path}", - f"ark:{text_int_path}", - "", - "", - "ark:-", - ], - stdout=subprocess.PIPE, - stderr=log_file, - env=os.environ, - ) - align_proc = subprocess.Popen( - [ - thirdparty_binary("lattice-align-words"), - word_boundary_path, - model_path, - "ark:-", - "ark:-", - ], - stdin=lin_proc.stdout, - stdout=subprocess.PIPE, - stderr=log_file, - env=os.environ, - ) - - prons_proc = subprocess.Popen( - [thirdparty_binary("nbest-to-prons"), model_path, "ark:-", pron_path], - stdin=align_proc.stdout, - stderr=log_file, - env=os.environ, - ) - prons_proc.communicate() - - -def generate_pronunciations( - aligner: Aligner, -) -> Tuple[Dict[str, defaultdict[Counter]], Dict[str, Dict[str, List[str, ...]]]]: - """ - Generates pronunciations based on alignments for a corpus and calculates pronunciation probabilities - - Parameters - ---------- - aligner: :class:`~montreal_forced_aligner.aligner.pretrained.PretrainedAligner` - Aligner - """ - jobs = [x.generate_pronunciations_arguments(aligner) for x in aligner.corpus.jobs] - if aligner.align_config.use_mp: - 
run_mp(generate_pronunciations_func, jobs, aligner.working_log_directory) - else: - run_non_mp(generate_pronunciations_func, jobs, aligner.working_log_directory) - pron_counts = {} - utt_mapping = {} - for j in aligner.corpus.jobs: - args = j.generate_pronunciations_arguments(aligner) - dict_data = j.dictionary_data() - for dict_name, pron_path in args.pron_paths.items(): - if dict_name not in pron_counts: - pron_counts[dict_name] = defaultdict(Counter) - utt_mapping[dict_name] = {} - word_lookup = dict_data[dict_name].reversed_words_mapping - phone_lookup = dict_data[dict_name].reversed_phone_mapping - with open(pron_path, "r", encoding="utf8") as f: - last_utt = None - for line in f: - line = line.split() - utt = line[0] - if utt not in utt_mapping[dict_name]: - if last_utt is not None: - utt_mapping[dict_name][last_utt].append("") - utt_mapping[dict_name][utt] = [""] - last_utt = utt - - word = word_lookup[int(line[3])] - if word == "": - utt_mapping[dict_name][utt].append(word) - else: - pron = tuple(phone_lookup[int(x)].split("_")[0] for x in line[4:]) - pron_string = " ".join(pron) - utt_mapping[dict_name][utt].append(word + " " + pron_string) - pron_counts[dict_name][word][pron] += 1 - return pron_counts, utt_mapping diff --git a/montreal_forced_aligner/segmenter.py b/montreal_forced_aligner/segmenter.py index 66be4159..91baa2c4 100644 --- a/montreal_forced_aligner/segmenter.py +++ b/montreal_forced_aligner/segmenter.py @@ -10,115 +10,314 @@ from __future__ import annotations import os -import shutil -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING, NamedTuple, Optional, Union -from .abc import MetaDict -from .config import TEMP_DIR +import yaml + +from .abc import FileExporterMixin, MetaDict, TopLevelMfaWorker +from .corpus.acoustic_corpus import AcousticCorpusMixin +from .corpus.classes import File, Speaker, Utterance +from .corpus.features import VadConfigMixin from .exceptions import KaldiProcessingError -from 
.multiprocessing.ivector import segment_vad -from .utils import log_kaldi_errors, parse_logs +from .helper import load_scp +from .utils import log_kaldi_errors, parse_logs, run_mp, run_non_mp if TYPE_CHECKING: - from logging import Logger - - from .config import SegmentationConfig - from .corpus import Corpus + from argparse import Namespace -SegmentationType = List[Dict[str, float]] +SegmentationType = list[dict[str, float]] __all__ = ["Segmenter"] -class Segmenter: +class SegmentVadArguments(NamedTuple): + """Arguments for :func:`~montreal_forced_aligner.segmenter.segment_vad_func`""" + + dictionaries: list[str] + vad_paths: dict[str, str] + segmentation_options: MetaDict + + +def get_initial_segmentation(frames: list[Union[int, str]], frame_shift: int) -> SegmentationType: + """ + Compute initial segmentation over voice activity + + Parameters + ---------- + frames: list[Union[int, str]] + List of frames with VAD output + frame_shift: int + Frame shift of features in ms + + Returns + ------- + SegmentationType + Initial segmentation + """ + segs = [] + cur_seg = None + silent_frames = 0 + non_silent_frames = 0 + for i, f in enumerate(frames): + if int(f) > 0: + non_silent_frames += 1 + if cur_seg is None: + cur_seg = {"begin": i * frame_shift} + else: + silent_frames += 1 + if cur_seg is not None: + cur_seg["end"] = (i - 1) * frame_shift + segs.append(cur_seg) + cur_seg = None + if cur_seg is not None: + cur_seg["end"] = len(frames) * frame_shift + segs.append(cur_seg) + return segs + + +def merge_segments( + segments: SegmentationType, + min_pause_duration: float, + max_segment_length: float, + snap_boundary_threshold: float, +) -> SegmentationType: + """ + Merge segments together + + Parameters + ---------- + segments: SegmentationType + Initial segments + min_pause_duration: float + Minimum amount of silence time to mark an utterance boundary + max_segment_length: float + Maximum length of segments before they're broken up + snap_boundary_threshold: + 
Boundary threshold to snap boundaries together + + Returns + ------- + SegmentationType + Merged segments + """ + merged_segs = [] + for s in segments: + if ( + not merged_segs + or s["begin"] > merged_segs[-1]["end"] + min_pause_duration + or s["end"] - merged_segs[-1]["begin"] > max_segment_length + ): + if s["end"] - s["begin"] > min_pause_duration: + if merged_segs and snap_boundary_threshold: + boundary_gap = s["begin"] - merged_segs[-1]["end"] + if boundary_gap < snap_boundary_threshold: + half_boundary = boundary_gap / 2 + else: + half_boundary = snap_boundary_threshold / 2 + merged_segs[-1]["end"] += half_boundary + s["begin"] -= half_boundary + + merged_segs.append(s) + else: + merged_segs[-1]["end"] = s["end"] + return merged_segs + + +def segment_vad_func( + dictionaries: list[str], + vad_paths: dict[str, str], + segmentation_options: MetaDict, +) -> dict[str, Utterance]: + """ + Multiprocessing function to generate segments from VAD output. + + See Also + -------- + :meth:`montreal_forced_aligner.segmenter.Segmenter.segment_vad` + Main function that calls this function in parallel + :meth:`montreal_forced_aligner.segmenter.Segmenter.segment_vad_arguments` + Job method for generating arguments for this function + :kaldi_utils:`segmentation.pl` + Kaldi utility + + Parameters + ---------- + dictionaries: list[str] + List of dictionary names + vad_paths: dict[str, str] + Dictionary of VAD archives per dictionary name + segmentation_options: dict[str, Any] + Options for segmentation + """ + + utterances = {} + + speaker = Speaker("speech") + for dict_name in dictionaries: + vad_path = vad_paths[dict_name] + + vad = load_scp(vad_path, data_type=int) + for recording, frames in vad.items(): + file = File(recording) + initial_segments = get_initial_segmentation( + frames, segmentation_options["frame_shift"] + ) + merged = merge_segments( + initial_segments, + segmentation_options["min_pause_duration"], + segmentation_options["max_segment_length"], + 
segmentation_options["snap_boundary_threshold"], + ) + for seg in merged: + utterances[recording] = Utterance( + speaker, file, begin=seg["begin"], end=seg["end"], text="speech" + ) + return utterances + + +class Segmenter(VadConfigMixin, AcousticCorpusMixin, FileExporterMixin, TopLevelMfaWorker): """ Class for performing speaker classification Parameters ---------- - corpus : :class:`~montreal_forced_aligner.corpus.Corpus` - Corpus object for the dataset - segmentation_config : :class:`~montreal_forced_aligner.config.SegmentationConfig` - Configuration for alignment - temp_directory : str, optional - Specifies the temporary directory root to save files need for Kaldi. - If not specified, it will be set to ``~/Documents/MFA`` - debug : bool - Flag for running in debug mode, defaults to false - verbose : bool - Flag for running in verbose mode, defaults to false - logger : :class:`~logging.Logger`, optional - Logger to use + max_segment_length : float + Maximum duration of segments + min_pause_duration : float + Minimum duration of pauses + snap_boundary_threshold : float + Threshold for snapping segment boundaries to each other """ def __init__( self, - corpus: Corpus, - segmentation_config: SegmentationConfig, - temp_directory: Optional[str] = None, - debug: Optional[bool] = False, - verbose: Optional[bool] = False, - logger: Optional[Logger] = None, + max_segment_length: float = 30, + min_pause_duration: float = 0.05, + snap_boundary_threshold: float = 0.15, + **kwargs ): - self.corpus = corpus - self.segmentation_config = segmentation_config - - if not temp_directory: - temp_directory = TEMP_DIR - self.temp_directory = temp_directory - self.debug = debug - self.verbose = verbose - self.logger = logger - self.uses_cmvn = False - self.uses_slices = False - self.uses_vad = False - self.speaker_independent = True - self.setup() + super().__init__(**kwargs) + self.max_segment_length = max_segment_length + self.min_pause_duration = min_pause_duration + 
self.snap_boundary_threshold = snap_boundary_threshold - @property - def segmenter_directory(self) -> str: - """Temporary directory for segmentation""" - return os.path.join(self.temp_directory, "segmentation") + @classmethod + def parse_parameters( + cls, + config_path: Optional[str] = None, + args: Optional[Namespace] = None, + unknown_args: Optional[list[str]] = None, + ) -> MetaDict: + """ + Parse parameters for segmentation from a config path or command-line arguments + + Parameters + ---------- + config_path: str + Config path + args: :class:`~argparse.Namespace` + Command-line arguments from argparse + unknown_args: list[str], optional + Extra command-line arguments + + Returns + ------- + dict[str, Any] + Configuration parameters + """ + global_params = {} + if config_path and os.path.exists(config_path): + with open(config_path, "r", encoding="utf8") as f: + data = yaml.load(f, Loader=yaml.SafeLoader) + for k, v in data.items(): + if k == "features": + if "type" in v: + v["feature_type"] = v["type"] + del v["type"] + global_params.update(v) + else: + global_params[k] = v + global_params.update(cls.parse_args(args, unknown_args)) + return global_params + + def segment_vad_arguments(self) -> list[SegmentVadArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.segmenter.segment_vad_func` + + Returns + ------- + list[SegmentVadArguments] + Arguments for processing + """ + return [ + SegmentVadArguments( + j.current_dictionary_names, + j.construct_path_dictionary(self.split_directory, "vad", "scp"), + self.segmentation_options, + ) + for j in self.jobs + ] @property - def vad_options(self) -> MetaDict: - """Options for performing VAD""" + def segmentation_options(self): + """Options for segmentation""" return { - "energy_threshold": self.segmentation_config.energy_threshold, - "energy_mean_scale": self.segmentation_config.energy_mean_scale, + "max_segment_length": self.max_segment_length, + "min_pause_duration": self.min_pause_duration, 
+ "snap_boundary_threshold": self.snap_boundary_threshold, + "frame_shift": round(self.frame_shift / 1000, 2), } @property - def use_mp(self) -> bool: - """Flag for whether to use multiprocessing""" - return self.segmentation_config.use_mp + def workflow_identifier(self) -> str: + """Segmentation workflow""" + return "segmentation" - def setup(self) -> None: + def segment_vad(self) -> None: """ - Sets up the corpus and segmenter for performing VAD + Run segmentation based off of VAD. - Raises - ------ - :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` - If there were any errors in running Kaldi binaries + See Also + -------- + segment_vad_func + Multiprocessing helper function for each job + segment_vad_arguments + Job method for generating arguments for helper function """ - done_path = os.path.join(self.segmenter_directory, "done") - if os.path.exists(done_path): - self.logger.info("Classification already done, skipping initialization.") - return - dirty_path = os.path.join(self.segmenter_directory, "dirty") - if os.path.exists(dirty_path): - shutil.rmtree(self.segmenter_directory) - log_dir = os.path.join(self.segmenter_directory, "log") + + jobs = self.segment_vad_arguments() + if self.use_mp: + segment_info = run_mp(segment_vad_func, jobs, self.features_log_directory, True) + else: + segment_info = run_non_mp(segment_vad_func, jobs, self.features_log_directory, True) + for j in self.jobs: + for old_utt, utterance in segment_info[j.name].items(): + old_utt = self.utterances[old_utt] + file = old_utt.file + if self.ignore_speakers: + if utterance.speaker_name not in self.speakers: + self.speakers[utterance.speaker_name] = Speaker(utterance.speaker_name) + speaker = self.speakers[utterance.speaker_name] + else: + speaker = old_utt.speaker + utterance.file = file + utterance.set_speaker(speaker) + self.add_utterance(utterance) + utterance_ids = [x.name for x in self.utterances.values() if x.begin is None] + for u in utterance_ids: + 
self.delete_utterance(u) + + def setup(self) -> None: + """Setup segmentation""" + self.check_previous_run() + log_dir = os.path.join(self.working_directory, "log") os.makedirs(log_dir, exist_ok=True) try: - self.corpus.initialize_corpus(None, self.segmentation_config.feature_config) + self.load_corpus() except Exception as e: - with open(dirty_path, "w"): - pass if isinstance(e, KaldiProcessingError): log_kaldi_errors(e.error_logs, self.logger) - e.update_log_file(self.logger.handlers[0].baseFilename) + e.update_log_file(self.logger) raise def segment(self) -> None: @@ -130,28 +329,31 @@ def segment(self) -> None: :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` If there were any errors in running Kaldi binaries """ - log_directory = os.path.join(self.segmenter_directory, "log") - dirty_path = os.path.join(self.segmenter_directory, "dirty") - done_path = os.path.join(self.segmenter_directory, "done") + self.setup() + log_directory = os.path.join(self.working_directory, "log") + done_path = os.path.join(self.working_directory, "done") if os.path.exists(done_path): self.logger.info("Classification already done, skipping.") return try: - self.corpus.compute_vad() + self.compute_vad() self.uses_vad = True - segment_vad(self) + self.segment_vad() parse_logs(log_directory) except Exception as e: - with open(dirty_path, "w"): - pass if isinstance(e, KaldiProcessingError): log_kaldi_errors(e.error_logs, self.logger) - e.update_log_file(self.logger.handlers[0].baseFilename) + e.update_log_file(self.logger) raise with open(done_path, "w"): pass - def export_segments(self, output_directory: str) -> None: + @property + def backup_output_directory(self) -> str: + """Backup output directory""" + return os.path.join(self.workflow_directory, "backup") + + def export_files(self, output_directory: str) -> None: """ Export the results of segmentation as TextGrids @@ -161,8 +363,8 @@ def export_segments(self, output_directory: str) -> None: Directory to save 
segmentation TextGrids """ backup_output_directory = None - if not self.segmentation_config.overwrite: - backup_output_directory = os.path.join(self.segmenter_directory, "transcriptions") + if not self.overwrite: + backup_output_directory = os.path.join(self.working_directory, "transcriptions") os.makedirs(backup_output_directory, exist_ok=True) - for f in self.corpus.files.values(): + for f in self.files.values(): f.save(output_directory, backup_output_directory) diff --git a/montreal_forced_aligner/speaker_classifier.py b/montreal_forced_aligner/speaker_classifier.py index c2ba62e2..31d8aef0 100644 --- a/montreal_forced_aligner/speaker_classifier.py +++ b/montreal_forced_aligner/speaker_classifier.py @@ -6,132 +6,104 @@ """ from __future__ import annotations -import logging import os -import shutil from typing import TYPE_CHECKING, Optional import numpy as np +import yaml +from sklearn.cluster import KMeans -from .abc import MetaDict -from .config import TEMP_DIR +from .abc import FileExporterMixin, TopLevelMfaWorker from .corpus.classes import Speaker +from .corpus.ivector_corpus import IvectorCorpusMixin from .exceptions import KaldiProcessingError from .helper import load_scp -from .multiprocessing import extract_ivectors +from .models import IvectorExtractorModel from .utils import log_kaldi_errors if TYPE_CHECKING: - from .config import SpeakerClassificationConfig - from .corpus import Corpus - from .models import IvectorExtractorModel + from argparse import Namespace + from .abc import MetaDict __all__ = ["SpeakerClassifier"] -class SpeakerClassifier: +class SpeakerClassifier(IvectorCorpusMixin, TopLevelMfaWorker, FileExporterMixin): """ - Class for performing speaker classification + Class for performing speaker classification, not currently very functional, but + is planned to be expanded in the future Parameters ---------- - corpus : :class:`~montreal_forced_aligner.corpus.Corpus` - Corpus object for the dataset - ivector_extractor : 
:class:`~montreal_forced_aligner.models.IvectorExtractorModel` - Configuration for alignment - classification_config : :class:`~montreal_forced_aligner.config.SpeakerClassificationConfig` - Configuration for alignment - compute_segments: bool, optional - Flag for whether segments should be created + ivector_extractor_path : str + Path to ivector extractor model num_speakers: int, optional Number of speakers in the corpus, if known cluster: bool, optional Flag for whether speakers should be clustered instead of classified - temp_directory : str, optional - Specifies the temporary directory root to save files need for Kaldi. - If not specified, it will be set to ``~/Documents/MFA`` - call_back : callable, optional - Specifies a call back function for diarization - debug : bool - Flag for running in debug mode, defaults to false - verbose : bool - Flag for running in verbose mode, defaults to false """ def __init__( - self, - corpus: Corpus, - ivector_extractor: IvectorExtractorModel, - classification_config: SpeakerClassificationConfig, - compute_segments: Optional[bool] = False, - num_speakers: Optional[int] = None, - cluster: Optional[bool] = False, - temp_directory: Optional[str] = None, - debug: Optional[bool] = False, - verbose: Optional[bool] = False, - logger: Optional[logging.Logger] = None, + self, ivector_extractor_path: str, num_speakers: int = 0, cluster: bool = True, **kwargs ): - self.corpus = corpus - self.ivector_extractor = ivector_extractor - self.feature_config = self.ivector_extractor.feature_config - self.classification_config = classification_config - - if not temp_directory: - temp_directory = TEMP_DIR - self.temp_directory = temp_directory - os.makedirs(self.temp_directory, exist_ok=True) - self.debug = debug - self.compute_segments = compute_segments - self.verbose = verbose - if logger is None: - self.log_file = os.path.join(self.temp_directory, "speaker_classifier.log") - self.logger = logging.getLogger("speaker_classifier") - 
self.logger.setLevel(logging.INFO) - handler = logging.FileHandler(self.log_file, "w", "utf-8") - handler.setFormatter = logging.Formatter("%(name)s %(message)s") - self.logger.addHandler(handler) - else: - self.logger = logger + self.ivector_extractor = IvectorExtractorModel(ivector_extractor_path) + kwargs.update(self.ivector_extractor.parameters) + super().__init__(**kwargs) self.classifier = None self.speaker_labels = {} self.ivectors = {} self.num_speakers = num_speakers self.cluster = cluster - self.uses_voiced = False - self.uses_cmvn = True - self.uses_splices = False - self.setup() - @property - def classify_directory(self) -> str: - """Temporary directory for speaker classification""" - return os.path.join(self.temp_directory, "speaker_classification") + @classmethod + def parse_parameters( + cls, + config_path: Optional[str] = None, + args: Optional[Namespace] = None, + unknown_args: Optional[list[str]] = None, + ) -> MetaDict: + """ + Parse parameters for speaker classification from a config path or command-line arguments - @property - def data_directory(self) -> str: - """Corpus data directory""" - return self.corpus.split_directory + Parameters + ---------- + config_path: str + Config path + args: :class:`~argparse.Namespace` + Command-line arguments from argparse + unknown_args: list[str], optional + Extra command-line arguments + + Returns + ------- + dict[str, Any] + Configuration parameters + """ + global_params = {} + if config_path and os.path.exists(config_path): + with open(config_path, "r", encoding="utf8") as f: + data = yaml.load(f, Loader=yaml.SafeLoader) + for k, v in data.items(): + if k == "features": + if "type" in v: + v["feature_type"] = v["type"] + del v["type"] + global_params.update(v) + else: + global_params[k] = v + global_params.update(cls.parse_args(args, unknown_args)) + return global_params @property - def working_directory(self) -> str: - """Current working directory for the speaker classifier""" - return 
self.classify_directory + def workflow_identifier(self) -> str: + """Speaker classification identifier""" + return "speaker_classification" @property def ie_path(self) -> str: - """Path for the IvectorExtractor model file""" + """Path for the ivector extractor model file""" return os.path.join(self.working_directory, "final.ie") - @property - def speaker_classification_model_path(self) -> str: - """Path for the speaker classification model""" - return os.path.join(self.working_directory, "speaker_classifier.mdl") - - @property - def speaker_labels_path(self) -> str: - """Path for the speaker labels file""" - return os.path.join(self.working_directory, "speaker_labels.txt") - @property def model_path(self) -> str: """Path for the acoustic model file""" @@ -142,37 +114,6 @@ def dubm_path(self) -> str: """Path for the DUBM model""" return os.path.join(self.working_directory, "final.dubm") - @property - def working_log_directory(self) -> str: - """Current log directory""" - return self.log_directory - - @property - def ivector_options(self) -> MetaDict: - """Ivector configuration options""" - data = self.ivector_extractor.meta - data["silence_weight"] = 0.0 - data["posterior_scale"] = 0.1 - data["max_count"] = 100 - data["sil_phones"] = None - return data - - @property - def log_directory(self) -> str: - """Log directory""" - return os.path.join(self.classify_directory, "log") - - @property - def use_mp(self) -> bool: - """Flag for whether to use multiprocessing""" - return self.classification_config.use_mp - - def extract_ivectors(self) -> None: - """ - Extract ivectors for the corpus - """ - extract_ivectors(self) - def setup(self) -> None: """ Sets up the corpus and speaker classifier @@ -182,25 +123,22 @@ def setup(self) -> None: :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` If there were any errors in running Kaldi binaries """ - done_path = os.path.join(self.classify_directory, "done") + + self.check_previous_run() + done_path = 
os.path.join(self.working_directory, "done") if os.path.exists(done_path): self.logger.info("Classification already done, skipping initialization.") return - dirty_path = os.path.join(self.classify_directory, "dirty") - if os.path.exists(dirty_path): - shutil.rmtree(self.classify_directory) - log_dir = os.path.join(self.classify_directory, "log") + log_dir = os.path.join(self.working_directory, "log") os.makedirs(log_dir, exist_ok=True) - self.ivector_extractor.export_model(self.classify_directory) try: - self.corpus.initialize_corpus(None, self.feature_config) + self.load_corpus() + self.ivector_extractor.export_model(self.working_directory) self.extract_ivectors() except Exception as e: - with open(dirty_path, "w"): - pass if isinstance(e, KaldiProcessingError): log_kaldi_errors(e.error_logs, self.logger) - e.update_log_file(self.logger.handlers[0].baseFilename) + e.update_log_file(self.logger) raise def load_ivectors(self) -> None: @@ -208,8 +146,7 @@ def load_ivectors(self) -> None: Load ivectors from the temporary directory """ self.ivectors = {} - for j in self.corpus.jobs: - ivectors_args = j.extract_ivector_arguments(self) + for ivectors_args in self.extract_ivectors_arguments(): for ivectors_path in ivectors_args.ivector_paths.values(): ivec = load_scp(ivectors_path) for utt, ivector in ivec.items(): @@ -220,7 +157,7 @@ def cluster_utterances(self) -> None: """ Cluster utterances based on their ivectors """ - from sklearn.cluster import KMeans + self.setup() if not self.ivectors: self.load_ivectors() @@ -228,16 +165,17 @@ def cluster_utterances(self) -> None: for v in self.ivectors.values(): x.append(v) x = np.array(x) - clust = KMeans(self.num_speakers).fit(x) - y = clust.labels_ + km = KMeans(self.num_speakers, max_iter=100) + km.fit(x) + y = km.labels_ for i, u in enumerate(self.ivectors.keys()): speaker_name = y[i] - utterance = self.corpus.utterances[u] - if speaker_name not in self.corpus.speakers: - self.corpus.speakers[speaker_name] = 
Speaker(speaker_name) - utterance.set_speaker(self.corpus.speakers[speaker_name]) + utterance = self.utterances[u] + if speaker_name not in self.speakers: + self.speakers[speaker_name] = Speaker(speaker_name) + utterance.set_speaker(self.speakers[speaker_name]) - def export_classification(self, output_directory: str) -> None: + def export_files(self, output_directory: str) -> None: """ Export files with their new speaker labels @@ -247,9 +185,9 @@ def export_classification(self, output_directory: str) -> None: Output directory to save files """ backup_output_directory = None - if not self.classification_config.overwrite: - backup_output_directory = os.path.join(self.classify_directory, "output") + if not self.overwrite: + backup_output_directory = os.path.join(self.working_directory, "output") os.makedirs(backup_output_directory, exist_ok=True) - for file in self.corpus.files.values(): + for file in self.files.values(): file.save(output_directory, backup_output_directory) diff --git a/montreal_forced_aligner/textgrid.py b/montreal_forced_aligner/textgrid.py index 506cf0c2..ab286607 100644 --- a/montreal_forced_aligner/textgrid.py +++ b/montreal_forced_aligner/textgrid.py @@ -6,19 +6,17 @@ from __future__ import annotations import os -import sys -import traceback -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING, Optional from praatio import textgrid as tgio -from .abc import Aligner from .data import CtmInterval if TYPE_CHECKING: - from .corpus.classes import DictionaryData, File, Speaker - from .dictionary import ReversedMappingType - from .multiprocessing.alignment import CtmType + from .abc import ReversedMappingType + from .alignment.base import CorpusAligner + from .corpus.classes import File, Speaker + from .dictionary import DictionaryData __all__ = [ "process_ctm_line", @@ -29,7 +27,6 @@ "export_textgrid", "ctm_to_textgrid", "output_textgrid_writing_errors", - "ctms_to_textgrids_non_mp", ] @@ -49,31 +46,37 @@ def 
process_ctm_line(line: str) -> CtmInterval: """ line = line.split(" ") utt = line[0] - begin = round(float(line[2]), 4) - duration = float(line[3]) - end = round(begin + duration, 4) - label = line[4] + if len(line) == 5: + begin = round(float(line[2]), 4) + duration = float(line[3]) + end = round(begin + duration, 4) + label = line[4] + else: + begin = round(float(line[1]), 4) + duration = float(line[2]) + end = round(begin + duration, 4) + label = line[3] return CtmInterval(begin, end, label, utt) def parse_from_word( - ctm_labels: List[CtmInterval], text: List[str], dictionary_data: DictionaryData -) -> List[CtmInterval]: + ctm_labels: list[CtmInterval], text: list[str], dictionary_data: DictionaryData +) -> list[CtmInterval]: """ Parse CTM intervals into the corresponding text for an utterance Parameters ---------- - ctm_labels: List[:class:`~montreal_forced_aligner.data.CtmInterval`] + ctm_labels: list[:class:`~montreal_forced_aligner.data.CtmInterval`] CTM intervals - text: List[str] + text: list[str] The original text that was to be aligned - dictionary_data: DictionaryData + dictionary_data: :class:`~montreal_forced_aligner.dictionary.DictionaryData` Dictionary data necessary for splitting subwords Returns ------- - List[:class:`~montreal_forced_aligner.data.CtmInterval`] + list[:class:`~montreal_forced_aligner.data.CtmInterval`] Correct intervals with subwords merged back into their original text """ cur_ind = 0 @@ -99,22 +102,22 @@ def parse_from_word( def parse_from_word_no_cleanup( - ctm_labels: List[CtmInterval], reversed_word_mapping: ReversedMappingType -) -> List[CtmInterval]: + ctm_labels: list[CtmInterval], reversed_word_mapping: ReversedMappingType +) -> list[CtmInterval]: """ Assume that subwords in the CTM files are desired, so just does a reverse look up to get the sub word text Parameters ---------- - ctm_labels: List[:class:`~montreal_forced_aligner.data.CtmInterval`] + ctm_labels: list[:class:`~montreal_forced_aligner.data.CtmInterval`] 
List of :class:`~montreal_forced_aligner.data.CtmInterval` to convert - reversed_word_mapping: Dict[int, str] + reversed_word_mapping: dict[int, str] Look up for Kaldi word IDs to convert them back to text Returns ------- - List[:class:`~montreal_forced_aligner.data.CtmInterval`] + list[:class:`~montreal_forced_aligner.data.CtmInterval`] Parsed intervals with text rather than integer IDs """ for ctm_interval in ctm_labels: @@ -124,25 +127,25 @@ def parse_from_word_no_cleanup( def parse_from_phone( - ctm_labels: List[CtmInterval], + ctm_labels: list[CtmInterval], reversed_phone_mapping: ReversedMappingType, - positions: List[str], -) -> List[CtmInterval]: + positions: list[str], +) -> list[CtmInterval]: """ Parse CtmIntervals to original phone transcriptions Parameters ---------- - ctm_labels: List[:class:`~montreal_forced_aligner.data.CtmInterval`] + ctm_labels: list[:class:`~montreal_forced_aligner.data.CtmInterval`] List of :class:`~montreal_forced_aligner.data.CtmInterval` to convert - reversed_phone_mapping: Dict[int, str] + reversed_phone_mapping: dict[int, str] Mapping to convert phone IDs to phone labels - positions: List[str] + positions: list[str] List of word positions to account for Returns ------- - List[:class:`~montreal_forced_aligner.data.CtmInterval`] + list[:class:`~montreal_forced_aligner.data.CtmInterval`] Parsed intervals with phone labels rather than IDs """ for ctm_interval in ctm_labels: @@ -154,120 +157,7 @@ def parse_from_phone( return ctm_labels -def ctms_to_textgrids_non_mp(aligner: Aligner) -> None: - """ - Parse CTM files to TextGrids without using multiprocessing - - Parameters - ---------- - aligner: :class:`~montreal_forced_aligner.aligner.base.BaseAligner` - Aligner that generated the CTM files - """ - - def process_current_word_labels(): - """Process the current stack of word labels""" - speaker = cur_utt.speaker - - text = cur_utt.text.split() - if aligner.align_config.cleanup_textgrids: - actual_labels = 
parse_from_word(current_labels, text, speaker.dictionary_data) - else: - actual_labels = parse_from_word_no_cleanup( - current_labels, speaker.dictionary_data.reversed_words_mapping - ) - cur_utt.word_labels = actual_labels - - def process_current_phone_labels(): - """Process the current stack of phone labels""" - speaker = cur_utt.speaker - - cur_utt.phone_labels = parse_from_phone( - current_labels, speaker.dictionary.reversed_phone_mapping, speaker.dictionary.positions - ) - - export_errors = {} - for j in aligner.corpus.jobs: - - word_arguments = j.cleanup_word_ctm_arguments(aligner) - phone_arguments = j.phone_ctm_arguments(aligner) - aligner.logger.debug(f"Parsing ctms for job {j.name}...") - cur_utt = None - current_labels = [] - for dict_name in word_arguments.dictionaries: - with open(word_arguments.ctm_paths[dict_name], "r") as f: - for line in f: - line = line.strip() - if line == "": - continue - ctm_interval = process_ctm_line(line) - utt = aligner.corpus.utterances[ctm_interval.utterance] - if cur_utt is None: - cur_utt = utt - if utt.is_segment: - utt_begin = utt.begin - else: - utt_begin = 0 - if utt != cur_utt: - process_current_word_labels() - cur_utt = utt - current_labels = [] - - ctm_interval.shift_times(utt_begin) - current_labels.append(ctm_interval) - if current_labels: - process_current_word_labels() - cur_utt = None - current_labels = [] - for dict_name in phone_arguments.dictionaries: - with open(phone_arguments.ctm_paths[dict_name], "r") as f: - for line in f: - line = line.strip() - if line == "": - continue - ctm_interval = process_ctm_line(line) - utt = aligner.corpus.utterances[ctm_interval.utterance] - if cur_utt is None: - cur_utt = utt - if utt.is_segment: - utt_begin = utt.begin - else: - utt_begin = 0 - if utt != cur_utt and cur_utt is not None: - process_current_phone_labels() - cur_utt = utt - current_labels = [] - - ctm_interval.shift_times(utt_begin) - current_labels.append(ctm_interval) - if current_labels: - 
process_current_phone_labels() - - aligner.logger.debug(f"Generating TextGrids for job {j.name}...") - processed_files = set() - for file in j.job_files().values(): - first_file_write = True - if file.name in processed_files: - first_file_write = False - try: - ctm_to_textgrid(file, aligner, first_file_write) - processed_files.add(file.name) - except Exception: - if aligner.align_config.debug: - raise - exc_type, exc_value, exc_traceback = sys.exc_info() - export_errors[file.name] = "\n".join( - traceback.format_exception(exc_type, exc_value, exc_traceback) - ) - if export_errors: - aligner.logger.warning( - f"There were {len(export_errors)} errors encountered in generating TextGrids. " - f"Check the output_errors.txt file in {os.path.join(aligner.textgrid_output)} " - f"for more details" - ) - output_textgrid_writing_errors(aligner.textgrid_output, export_errors) - - -def output_textgrid_writing_errors(output_directory: str, export_errors: Dict[str, str]) -> None: +def output_textgrid_writing_errors(output_directory: str, export_errors: dict[str, str]) -> None: """ Output any errors that were encountered in writing TextGrids @@ -275,7 +165,7 @@ def output_textgrid_writing_errors(output_directory: str, export_errors: Dict[st ---------- output_directory: str Directory to save TextGrids files - export_errors: Dict[str, str] + export_errors: dict[str, str] Dictionary of errors encountered """ error_log = os.path.join(output_directory, "output_errors.txt") @@ -294,7 +184,7 @@ def output_textgrid_writing_errors(output_directory: str, export_errors: Dict[st def generate_tiers( file: File, cleanup_textgrids: Optional[bool] = True -) -> Dict[Speaker, Dict[str, CtmType]]: +) -> dict[Speaker, dict[str, list[CtmInterval]]]: """ Generate TextGrid tiers for a given File @@ -307,7 +197,7 @@ def generate_tiers( Returns ------- - Dict[Speaker, Dict[str, CtmType]] + dict[Speaker, dict[str, list[:class:`~montreal_forced_aligner.data.CtmInterval`]] Tier information per speaker, with 
:class:`~montreal_forced_aligner.data.CtmInterval` split by "phones" and "words" """ output = {} @@ -320,7 +210,7 @@ def generate_tiers( words = [] phones = [] - if dictionary_data.dictionary_config.multilingual_ipa and cleanup_textgrids: + if dictionary_data.multilingual_ipa and cleanup_textgrids: phone_ind = 0 for interval in u.word_labels: end = interval.end @@ -329,16 +219,14 @@ def generate_tiers( word, ) subwords = [ - x - if x in dictionary_data.words_mapping - else dictionary_data.dictionary_config.oov_word + x if x in dictionary_data.words_mapping else dictionary_data.oov_word for x in subwords ] subprons = [dictionary_data.words[x] for x in subwords] cur_phones = [] while u.phone_labels[phone_ind].end <= end: p = u.phone_labels[phone_ind] - if p.label in dictionary_data.dictionary_config.silence_phones: + if p.label in dictionary_data.silence_phones: phone_ind += 1 continue cur_phones.append(p) @@ -354,10 +242,7 @@ def generate_tiers( for interval in u.word_labels: words.append(interval) for interval in u.phone_labels: - if ( - interval.label in dictionary_data.dictionary_config.silence_phones - and cleanup_textgrids - ): + if interval.label in dictionary_data.silence_phones and cleanup_textgrids: continue phones.append(interval) if speaker not in output: @@ -371,7 +256,7 @@ def generate_tiers( def export_textgrid( file: File, output_path: str, - speaker_data: Dict[Speaker, Dict[str, CtmType]], + speaker_data: dict[Speaker, dict[str, list[CtmInterval]]], frame_shift: int, first_file_write: Optional[bool] = True, ) -> None: @@ -384,7 +269,7 @@ def export_textgrid( File object to export output_path: str Output path of the file - speaker_data: Dict[Speaker, Dict[str, List[:class:`~montreal_forced_aligner.data.CtmInterval`]] + speaker_data: dict[Speaker, dict[str, list[:class:`~montreal_forced_aligner.data.CtmInterval`]] Per speaker, per word/phone :class:`~montreal_forced_aligner.data.CtmInterval` frame_shift: int Frame shift of features, in ms @@ -454,7 
+339,7 @@ def export_textgrid( tg.save(output_path, includeBlankSpaces=True, format="long_textgrid", reportingMode="error") -def ctm_to_textgrid(file: File, aligner: Aligner, first_file_write=True) -> None: +def ctm_to_textgrid(file: File, aligner: CorpusAligner, first_file_write=True) -> None: """ Export a File to TextGrid @@ -462,18 +347,16 @@ def ctm_to_textgrid(file: File, aligner: Aligner, first_file_write=True) -> None ---------- file: File File to export - aligner: :class:`~montreal_forced_aligner.aligner.base.BaseAligner` or :class:`~montreal_forced_aligner.trainers.BaseTrainer` + aligner: CorpusAligner Aligner used to generate the alignments first_file_write: bool, optional Flag for whether this is the first time touching this file """ - data = generate_tiers(file, cleanup_textgrids=aligner.align_config.cleanup_textgrids) + data = generate_tiers(file, cleanup_textgrids=aligner.cleanup_textgrids) backup_output_directory = None - if not aligner.align_config.overwrite: + if not aligner.overwrite: backup_output_directory = aligner.backup_output_directory os.makedirs(backup_output_directory, exist_ok=True) output_path = file.construct_output_path(aligner.textgrid_output, backup_output_directory) - export_textgrid( - file, output_path, data, aligner.align_config.feature_config.frame_shift, first_file_write - ) + export_textgrid(file, output_path, data, aligner.frame_shift, first_file_write) diff --git a/montreal_forced_aligner/trainers/__init__.py b/montreal_forced_aligner/trainers/__init__.py deleted file mode 100644 index 8a03609b..00000000 --- a/montreal_forced_aligner/trainers/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -""" -Training acoustic models -======================== - - -""" -from .base import BaseTrainer # noqa -from .ivector_extractor import IvectorExtractorTrainer # noqa -from .lda import LdaTrainer # noqa -from .monophone import MonophoneTrainer # noqa -from .sat import SatTrainer # noqa -from .triphone import TriphoneTrainer # noqa - -__all__ = 
[ - "BaseTrainer", - "IvectorExtractorTrainer", - "LdaTrainer", - "MonophoneTrainer", - "SatTrainer", - "TriphoneTrainer", - "base", - "ivector_extractor", - "lda", - "monophone", - "sat", - "triphone", -] - -BaseTrainer.__module__ = "montreal_forced_aligner.trainers" -IvectorExtractorTrainer.__module__ = "montreal_forced_aligner.trainers" -LdaTrainer.__module__ = "montreal_forced_aligner.trainers" -MonophoneTrainer.__module__ = "montreal_forced_aligner.trainers" -SatTrainer.__module__ = "montreal_forced_aligner.trainers" -TriphoneTrainer.__module__ = "montreal_forced_aligner.trainers" diff --git a/montreal_forced_aligner/trainers/base.py b/montreal_forced_aligner/trainers/base.py deleted file mode 100644 index 6ba82361..00000000 --- a/montreal_forced_aligner/trainers/base.py +++ /dev/null @@ -1,594 +0,0 @@ -"""Class definition for BaseTrainer""" -from __future__ import annotations - -import os -import re -import shutil -import time -from typing import TYPE_CHECKING, Any, Dict, List, Optional - -from tqdm import tqdm - -from ..abc import Aligner, MetaDict, Trainer -from ..config import FeatureConfig -from ..exceptions import KaldiProcessingError, TrainerError -from ..models import AcousticModel -from ..multiprocessing.alignment import ( - acc_stats, - align, - compile_information, - compile_train_graphs, - compute_alignment_improvement, - convert_ali_to_textgrids, -) -from ..utils import log_kaldi_errors, parse_logs - -if TYPE_CHECKING: - from ..corpus import Corpus - from ..dictionary import MultispeakerDictionary - - -__all__ = ["BaseTrainer"] - - -class BaseTrainer(Aligner, Trainer): - """ - Base trainer class for training acoustic models and ivector extractors - - Parameters - ---------- - default_feature_config: :class:`~montreal_forced_aligner.config.FeatureConfig` - Default feature config - - Attributes - ---------- - feature_config : :class:`~montreal_forced_aligner.config.FeatureConfig` - Feature configuration - num_iterations : int - Number of training 
iterations to perform, defaults to 40 - transition_scale : float - Scaling of transition costs in alignment, defaults to 1.0 - acoustic_scale : float - Scaling of acoustic costs in alignment, defaults to 0.1 - self_loop_scale : float - Scaling of self loop costs in alignment, defaults to 0.1 - beam : int - Default beam width for alignment, defaults = 10 - retry_beam : int - Beam width to fall back on if no alignment is produced, defaults to 40 - max_gaussians : int - Total number of gaussians, defaults to 1000 - boost_silence : float - Factor by which to boost silence likelihoods in alignment, defaults to 1.0 - realignment_iterations : list - List of iterations to perform alignment - power : float - Exponent for number of gaussians according to occurrence counts, defaults to 0.25 - debug: bool - Flag for debug mode - use_mp: bool - Flag for whether to use multiprocessing - iteration: int - Current iteration - training_complete: bool - Flag for whether training has been successfully completed - speaker_independent: bool - Flag for using speaker-independent features regardless of speaker adaptation - uses_cmvn: bool - Flag for whether to include CMVN in features - uses_splices: bool - Flag for whether to include splices in features - uses_voiced: bool - Flag for whether to use voiced features - """ - - def __init__(self, default_feature_config: FeatureConfig): - self.logger = None - self.dictionary: Optional[MultispeakerDictionary] = None - self.transition_scale = 1.0 - self.acoustic_scale = 0.1 - self.self_loop_scale = 0.1 - self.realignment_iterations = [] - self.num_iterations = 40 - self.beam = 10 - self.retry_beam = 40 - self.max_gaussians = 1000 - self.boost_silence = 1.0 - self.power = 0.25 - self.subset = None - self.calc_pron_probs = False - self.architecture = "gmm-hmm" - self.feature_config = FeatureConfig() - self.feature_config.update(default_feature_config.params()) - self.initial_gaussians = None # Gets set later - self.temp_directory = None - 
self.identifier = None - self.corpus: Optional[Corpus] = None - self.data_directory = None - self.debug = False - self.use_mp = True - self.current_gaussians = None - self.iteration = 0 - self.training_complete = False - self.speaker_independent = True - self.uses_cmvn = True - self.uses_splices = False - self.uses_voiced = False - self.previous_trainer: Optional[BaseTrainer] = None - - @property - def train_directory(self) -> str: - """Training directory""" - return os.path.join(self.temp_directory, self.identifier) - - @property - def log_directory(self) -> str: - """Training log directory""" - return os.path.join(self.train_directory, "log") - - @property - def align_directory(self) -> str: - """Alignment directory""" - return os.path.join(self.temp_directory, f"{self.identifier}_ali") - - @property - def align_log_directory(self) -> str: - """Alignment log directory""" - return os.path.join(self.align_directory, "log") - - @property - def working_directory(self) -> str: - """Current working directory""" - if self.training_complete: - return self.align_directory - return self.train_directory - - @property - def working_log_directory(self) -> str: - """Log directory of current working directory""" - if self.training_complete: - return self.align_log_directory - return self.log_directory - - @property - def fmllr_options(self) -> MetaDict: - """Options for fMLLR calculation, only used by SatTrainer""" - raise NotImplementedError - - @property - def lda_options(self) -> MetaDict: - """Options for LDA calculation, only used by LdaTrainer""" - raise NotImplementedError - - @property - def tree_path(self): - """Path to tree file""" - return os.path.join(self.working_directory, "tree") - - @property - def current_model_path(self): - """Current acoustic model path""" - if ( - self.training_complete - or self.iteration is None - or self.iteration > self.num_iterations - ): - return os.path.join(self.working_directory, "final.mdl") - return 
os.path.join(self.working_directory, f"{self.iteration}.mdl") - - @property - def model_path(self) -> str: - """Current acoustic model path""" - return self.current_model_path - - @property - def next_model_path(self): - """Next iteration's acoustic model path""" - if self.iteration > self.num_iterations: - return os.path.join(self.working_directory, "final.mdl") - return os.path.join(self.working_directory, f"{self.iteration + 1}.mdl") - - @property - def next_occs_path(self): - """Next iteration's occs file path""" - if self.training_complete: - return os.path.join(self.working_directory, "final.occs") - return os.path.join(self.working_directory, f"{self.iteration + 1}.occs") - - @property - def alignment_model_path(self): - """Alignment model path""" - path = os.path.join(self.working_directory, "final.alimdl") - if self.speaker_independent and os.path.exists(path): - return path - if not self.training_complete: - return self.current_model_path - return os.path.join(self.working_directory, "final.mdl") - - def compute_calculated_properties(self) -> None: - """Compute any calculated properties such as alignment iterations""" - pass - - @property - def train_type(self) -> str: - """Training type, not implemented for BaseTrainer""" - raise NotImplementedError - - @property - def phone_type(self) -> str: - """Phone type, not implemented for BaseTrainer""" - raise NotImplementedError - - @property - def final_gaussian_iteration(self) -> int: - """Final iteration to increase gaussians""" - return self.num_iterations - 10 - - @property - def gaussian_increment(self) -> int: - """Amount by which gaussians should be increases each iteration""" - return int((self.max_gaussians - self.initial_gaussians) / self.final_gaussian_iteration) - - @property - def align_options(self) -> MetaDict: - """Options for alignment""" - options_silence_csl = "" - if self.dictionary: - options_silence_csl = self.dictionary.config.optional_silence_csl - return { - "beam": self.beam, - 
"retry_beam": self.retry_beam, - "transition_scale": self.transition_scale, - "acoustic_scale": self.acoustic_scale, - "self_loop_scale": self.self_loop_scale, - "boost_silence": self.boost_silence, - "debug": self.debug, - "optional_silence_csl": options_silence_csl, - } - - def analyze_align_stats(self) -> None: - """ - Analyzes alignment stats and outputs debug information - """ - unaligned, log_like = compile_information(self) - - self.logger.debug( - f"Average per frame likelihood (this might not actually mean anything) " - f"for {self.identifier}: {log_like}" - ) - self.logger.debug(f"Number of unaligned files " f"for {self.identifier}: {len(unaligned)}") - - def update(self, data: Dict[str, Any]) -> None: - """ - Update configuration data - - Parameters - ---------- - data: Dict[str, Any] - Data to update - """ - from ..config.base_config import PARSING_KEYS - - for k, v in data.items(): - if k == "use_mp": - self.feature_config.use_mp = v - if k == "features": - self.feature_config.update(v) - elif k in PARSING_KEYS: - continue - elif not hasattr(self, k): - raise TrainerError(f"No field found for key {k}") - else: - setattr(self, k, v) - self.compute_calculated_properties() - - def _setup_for_init( - self, - identifier: str, - temporary_directory: str, - corpus: Corpus, - dictionary: MultispeakerDictionary, - previous_trainer: Optional[BaseTrainer], - ) -> None: - """ - Default initialization for all Trainers - - Parameters - ---------- - identifier: str - Identifier for the training block - temporary_directory: str - Root temporary directory to save - corpus: :class:`~montreal_forced_aligner.corpus.Corpus` - Corpus to use - dictionary: :class:`~montreal_forced_aligner.dictionary.MultispeakerDictionary` - MultispeakerDictionary to use - previous_trainer: :class:`~montreal_forced_aligner.trainers.BaseTrainer`, optional - Previous trainer to initialize from - - Raises - ------ - :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` - If there 
were any errors in running Kaldi binaries - """ - begin = time.time() - self.temp_directory = temporary_directory - self.identifier = identifier - dirty_path = os.path.join(self.train_directory, "dirty") - done_path = os.path.join(self.align_directory, "done") - if os.path.exists(dirty_path): # if there was an error, let's redo from scratch - shutil.rmtree(self.train_directory) - self.logger.info(f"Initializing training for {identifier}...") - self.corpus = corpus - try: - self.data_directory = self.corpus.split_directory - self.corpus.generate_features() - if self.subset is not None: - self.data_directory = self.corpus.subset_directory(self.subset) - except Exception as e: - if isinstance(e, KaldiProcessingError): - log_kaldi_errors(e.error_logs, self.logger) - e.update_log_file(self.logger.handlers[0].baseFilename) - raise - self.dictionary = dictionary - self.previous_trainer = previous_trainer - if os.path.exists(done_path): - self.training_complete = True - self.iteration = None - return - os.makedirs(self.train_directory, exist_ok=True) - os.makedirs(self.log_directory, exist_ok=True) - if self.subset is not None and self.subset > corpus.num_utterances: - self.logger.warning( - "Subset specified is larger than the dataset, " - "using full corpus for this training block." 
- ) - - self.logger.debug(f"Setup for initialization took {time.time() - begin} seconds") - - def increment_gaussians(self): - """Increment the current number of gaussians""" - self.current_gaussians += self.gaussian_increment - - def init_training( - self, - identifier: str, - temporary_directory: str, - corpus: Corpus, - dictionary: MultispeakerDictionary, - previous_trainer: Optional[BaseTrainer], - ) -> None: - """ - Initialize training, not implemented for BaseTrainer - - Parameters - ---------- - identifier: str - Identifier for the training block - temporary_directory: str - Root temporary directory to save - corpus: :class:`~montreal_forced_aligner.corpus.Corpus` - Corpus to use - dictionary: :class:`~montreal_forced_aligner.dictionary.MultispeakerDictionary` - MultispeakerDictionary to use - previous_trainer: :class:`~montreal_forced_aligner.trainers.BaseTrainer`, optional - Previous trainer to initialize from - """ - raise NotImplementedError - - def get_unaligned_utterances(self) -> List[str]: - """Find all utterances that were not aligned for validation utility""" - error_regex = re.compile(r"Did not successfully decode file (\w+),") - error_files = [] - for j in self.corpus.jobs: - path = os.path.join(self.align_directory, "log", f"align.{j.name}.log") - if not os.path.exists(path): - continue - with open(path, "r") as f: - error_files.extend(error_regex.findall(f.read())) - return error_files - - def align(self, subset: Optional[int] = None) -> None: - """ - Align a subset of the corpus for the next trainer - - Parameters - ---------- - subset: int, optional - Number of utterances to include in the subset - - Raises - ------ - :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` - If there were any errors in running Kaldi binaries - """ - if not os.path.exists(self.align_directory): - self.finalize_training() - dirty_path = os.path.join(self.align_directory, "dirty") - if os.path.exists(dirty_path): # if there was an error, let's redo 
from scratch - shutil.rmtree(self.align_directory) - done_path = os.path.join(self.align_directory, "done") - if not os.path.exists(done_path): - message = f"Generating alignments using {self.identifier} models" - if subset: - message += f" using {subset} utterances..." - else: - message += " for the whole corpus..." - self.logger.info(message) - begin = time.time() - if subset is None: - self.data_directory = self.corpus.split_directory - else: - self.data_directory = self.corpus.subset_directory(subset) - try: - self.iteration = None - compile_train_graphs(self) - align(self) - self.analyze_align_stats() - self.save(os.path.join(self.align_directory, "acoustic_model.zip")) - except Exception as e: - with open(dirty_path, "w"): - pass - if isinstance(e, KaldiProcessingError): - log_kaldi_errors(e.error_logs, self.logger) - e.update_log_file(self.logger.handlers[0].baseFilename) - raise - with open(done_path, "w"): - pass - self.logger.debug(f"Alignment took {time.time() - begin} seconds") - else: - self.logger.info(f"Alignments using {self.identifier} models already done") - - def training_iteration(self): - """Perform an iteration of training""" - if os.path.exists(self.next_model_path): - self.iteration += 1 - return - if self.iteration in self.realignment_iterations: - align(self) - if self.debug: - compute_alignment_improvement(self) - acc_stats(self) - - parse_logs(self.log_directory) - if self.iteration < self.final_gaussian_iteration: - self.increment_gaussians() - self.iteration += 1 - - def train(self): - """ - Train the model - - Raises - ------ - :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` - If there were any errors in running Kaldi binaries - """ - done_path = os.path.join(self.train_directory, "done") - dirty_path = os.path.join(self.train_directory, "dirty") - if os.path.exists(done_path): - self.logger.info(f"{self.identifier} training already done, skipping initialization.") - return - begin = time.time() - try: - with 
tqdm(total=self.num_iterations) as pbar: - while self.iteration < self.num_iterations + 1: - self.training_iteration() - pbar.update(1) - self.finalize_training() - except Exception as e: - with open(dirty_path, "w"): - pass - if isinstance(e, KaldiProcessingError): - log_kaldi_errors(e.error_logs, self.logger) - e.update_log_file(self.logger.handlers[0].baseFilename) - raise - with open(done_path, "w"): - pass - self.logger.info("Training complete!") - self.logger.debug(f"Training took {time.time() - begin} seconds") - - def finalize_training(self): - """ - Finalize the training, moving all relevant files from the training directory to the - alignment directory and changing flags to point at align directory as the working directory - - """ - os.makedirs(self.align_directory, exist_ok=True) - os.makedirs(self.align_log_directory, exist_ok=True) - shutil.copy( - os.path.join(self.train_directory, f"{self.num_iterations}.mdl"), - os.path.join(self.train_directory, "final.mdl"), - ) - shutil.copy( - os.path.join(self.train_directory, f"{self.num_iterations}.occs"), - os.path.join(self.train_directory, "final.occs"), - ) - shutil.copy(os.path.join(self.train_directory, "tree"), self.align_directory) - shutil.copyfile( - os.path.join(self.train_directory, "final.mdl"), - os.path.join(self.align_directory, "final.mdl"), - ) - - if os.path.exists(os.path.join(self.train_directory, "lda.mat")): - shutil.copyfile( - os.path.join(self.train_directory, "lda.mat"), - os.path.join(self.align_directory, "lda.mat"), - ) - shutil.copyfile( - os.path.join(self.train_directory, "final.occs"), - os.path.join(self.align_directory, "final.occs"), - ) - if not self.debug: - for i in range(1, self.num_iterations): - model_path = os.path.join(self.train_directory, f"{i}.mdl") - try: - os.remove(model_path) - except FileNotFoundError: - pass - try: - os.remove(os.path.join(self.train_directory, f"{i}.occs")) - except FileNotFoundError: - pass - self.training_complete = True - 
self.iteration = None - - @property - def meta(self) -> MetaDict: - """Generate metadata for the acoustic model that was trained""" - from datetime import datetime - - from ..utils import get_mfa_version - - data = { - "phones": sorted(self.dictionary.config.non_silence_phones), - "version": get_mfa_version(), - "architecture": self.architecture, - "train_date": str(datetime.now()), - "features": self.feature_config.params(), - "multilingual_ipa": self.dictionary.config.multilingual_ipa, - } - if self.dictionary.config.multilingual_ipa: - data["strip_diacritics"] = self.dictionary.config.strip_diacritics - data["digraphs"] = self.dictionary.config.digraphs - return data - - def export_textgrids(self) -> None: - """ - Export a TextGrid file for every sound file in the dataset - - Raises - ------ - :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` - If there were any errors in running Kaldi binaries - """ - begin = time.time() - try: - convert_ali_to_textgrids(self) - except Exception as e: - if isinstance(e, KaldiProcessingError): - log_kaldi_errors(e.error_logs, self.logger) - e.update_log_file(self.logger.handlers[0].baseFilename) - raise - self.logger.debug(f"Exporting textgrids took {time.time() - begin} seconds") - - def save(self, path: str, root_directory: Optional[str] = None) -> None: - """ - Export an acoustic model and dictionary to the specified path - - Parameters - ---------- - path : str - Path to save acoustic model and dictionary - root_directory : str or None - Path for root directory of temporary files - """ - directory, filename = os.path.split(path) - basename, _ = os.path.splitext(filename) - acoustic_model = AcousticModel.empty(basename, root_directory=root_directory) - acoustic_model.add_meta_file(self) - acoustic_model.add_model(self.train_directory) - if directory: - os.makedirs(directory, exist_ok=True) - basename, _ = os.path.splitext(path) - acoustic_model.dump(path) diff --git 
a/montreal_forced_aligner/trainers/ivector_extractor.py b/montreal_forced_aligner/trainers/ivector_extractor.py deleted file mode 100644 index 0c913979..00000000 --- a/montreal_forced_aligner/trainers/ivector_extractor.py +++ /dev/null @@ -1,438 +0,0 @@ -"""Class definition for IvectorExtractorTrainer""" -from __future__ import annotations - -import os -import shutil -import subprocess -import time -from typing import TYPE_CHECKING, Any, Dict, Optional - -from tqdm import tqdm - -from ..abc import IvectorExtractor, MetaDict -from ..exceptions import KaldiProcessingError -from ..helper import load_scp -from ..models import IvectorExtractorModel -from ..multiprocessing.ivector import ( - acc_global_stats, - acc_ivector_stats, - extract_ivectors, - gauss_to_post, - gmm_gselect, -) -from ..utils import log_kaldi_errors, parse_logs, thirdparty_binary -from .base import BaseTrainer - -if TYPE_CHECKING: - from ..abc import Dictionary - from ..aligner import PretrainedAligner - from ..config import FeatureConfig - from ..corpus import Corpus - - -IvectorConfigType = Dict[str, Any] - - -__all__ = ["IvectorExtractorTrainer"] - - -class IvectorExtractorTrainer(BaseTrainer, IvectorExtractor): - """ - Trainer for IvectorExtractor - - Attributes - ---------- - ivector_dimension : int - Dimension of the extracted job_name-vector - ivector_period : int - Number of frames between job_name-vector extractions - num_iterations : int - Number of training iterations to perform - num_gselect : int - Gaussian-selection using diagonal model: number of Gaussians to select - posterior_scale : float - Scale on the acoustic posteriors, intended to account for inter-frame correlations - min_post : float - Minimum posterior to use (posteriors below this are pruned out) - subsample : int - Speeds up training; training on every job_name'th feature - max_count : int - The use of this option (e.g. 
--max-count 100) can make iVectors more consistent for different lengths of utterance, by scaling up the prior term when the data-count exceeds this value. The data-count is after posterior-scaling, so assuming the posterior-scale is 0.1, --max-count 100 starts having effect after 1000 frames, or 10 seconds of data. - """ - - def __init__(self, default_feature_config: FeatureConfig): - super(IvectorExtractorTrainer, self).__init__(default_feature_config) - - self.ubm_num_iterations = 4 - self.ubm_num_gselect = 30 - self.ubm_num_frames = 500000 - self.ubm_num_gaussians = 256 - self.ubm_num_iterations_init = 20 - self.ubm_initial_gaussian_proportion = 0.5 - self.ubm_min_gaussian_weight = 0.0001 - - self.ubm_remove_low_count_gaussians = True - - self.ivector_dimension = 128 - self.num_iterations = 10 - self.num_gselect = 20 - self.posterior_scale = 1.0 - self.silence_weight = 0.0 - self.min_post = 0.025 - self.gaussian_min_count = 100 - self.subsample = 5 - self.max_count = 100 - self.apply_cmn = True - self.previous_align_directory = None - self.dubm_training_complete = False - self.ubm_training_complete = False - - @property - def meta(self) -> MetaDict: - """Metadata information for IvectorExtractor""" - from ..utils import get_mfa_version - - return { - "version": get_mfa_version(), - "ivector_dimension": self.ivector_dimension, - "apply_cmn": self.apply_cmn, - "num_gselect": self.num_gselect, - "min_post": self.min_post, - "posterior_scale": self.posterior_scale, - "features": self.feature_config.params(), - } - - @property - def train_type(self) -> str: - """Training identifier""" - return "ivector" - - @property - def align_directory(self) -> str: - """Alignment directory""" - return self.train_directory - - @property - def ivector_options(self) -> MetaDict: - """Options for ivector training and extracting""" - return { - "subsample": self.subsample, - "num_gselect": self.num_gselect, - "posterior_scale": self.posterior_scale, - "min_post": self.min_post, - 
"silence_weight": self.silence_weight, - "max_count": self.max_count, - "ivector_dimension": self.ivector_dimension, - "sil_phones": self.dictionary.config.silence_csl, - } - - @property - def current_ie_path(self) -> str: - """Current ivector extractor model path""" - if ( - self.training_complete - or self.iteration is None - or self.iteration > self.num_iterations - ): - return os.path.join(self.working_directory, "final.ie") - return os.path.join(self.working_directory, f"{self.iteration}.ie") - - @property - def next_ie_path(self) -> str: - """Next iteration's ivector extractor model path""" - if self.iteration > self.num_iterations: - return os.path.join(self.working_directory, "final.ie") - return os.path.join(self.working_directory, f"{self.iteration + 1}.ie") - - @property - def dubm_path(self) -> str: - """DUBM model path""" - return os.path.join(self.working_directory, "final.dubm") - - @property - def current_dubm_path(self) -> str: - """Current iteration's DUBM model path""" - if self.dubm_training_complete: - return os.path.join(self.working_directory, "final.dubm") - return os.path.join(self.working_directory, f"{self.iteration}.dubm") - - @property - def next_dubm_path(self) -> str: - """Next iteration's DUBM model path""" - if self.dubm_training_complete: - return os.path.join(self.working_directory, "final.dubm") - return os.path.join(self.working_directory, f"{self.iteration + 1}.dubm") - - @property - def ie_path(self) -> str: - """Ivector extractor model path""" - return os.path.join(self.working_directory, "final.ie") - - @property - def model_path(self) -> str: - """Acoustic model path""" - return os.path.join(self.working_directory, "final.mdl") - - def train_ubm_iteration(self) -> None: - """ - Run an iteration of UBM training - """ - # Accumulate stats - acc_global_stats(self) - self.iteration += 1 - - def finalize_train_ubm(self) -> None: - """Finalize DUBM training""" - final_dubm_path = os.path.join(self.train_directory, "final.dubm") - 
shutil.copy( - os.path.join(self.train_directory, f"{self.ubm_num_iterations}.dubm"), final_dubm_path - ) - self.iteration = 0 - self.dubm_training_complete = True - - def train_ubm(self) -> None: - """ - Train UBM for ivector extractor - - Raises - ------ - :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` - If there were any errors in running Kaldi binaries - """ - # train diag ubm - dirty_path = os.path.join(self.train_directory, "dirty") - final_ubm_path = os.path.join(self.train_directory, "final.ubm") - if os.path.exists(final_ubm_path): - return - try: - begin = time.time() - self.logger.info("Initializing diagonal UBM...") - # Initialize model from E-M in memory - log_directory = os.path.join(self.train_directory, "log") - num_gauss_init = int( - self.ubm_initial_gaussian_proportion * int(self.ubm_num_gaussians) - ) - log_path = os.path.join(log_directory, "gmm_init.log") - feat_name = self.feature_config.feature_id - all_feats_path = os.path.join(self.corpus.output_directory, f"{feat_name}.scp") - feature_string = self.corpus.jobs[0].construct_base_feature_string( - self.corpus, all_feats=True - ) - with open(all_feats_path, "w") as outf: - for i in self.corpus.jobs: - feat_paths = i.construct_path_dictionary(self.data_directory, "feats", "scp") - for p in feat_paths.values(): - with open(p) as inf: - for line in inf: - outf.write(line) - self.iteration = 0 - with open(log_path, "w") as log_file: - gmm_init_proc = subprocess.Popen( - [ - thirdparty_binary("gmm-global-init-from-feats"), - f"--num-threads={self.corpus.num_jobs}", - f"--num-frames={self.ubm_num_frames}", - f"--num_gauss={self.ubm_num_gaussians}", - f"--num_gauss_init={num_gauss_init}", - f"--num_iters={self.ubm_num_iterations_init}", - feature_string, - self.current_dubm_path, - ], - stderr=log_file, - ) - gmm_init_proc.communicate() - # Store Gaussian selection indices on disk - gmm_gselect(self) - final_dubm_path = os.path.join(self.train_directory, "final.dubm") - - if not 
os.path.exists(final_dubm_path): - self.logger.info("Training diagonal UBM...") - with tqdm(total=self.ubm_num_iterations) as pbar: - while self.iteration < self.ubm_num_iterations + 1: - self.train_ubm_iteration() - pbar.update(1) - self.finalize_train_ubm() - parse_logs(log_directory) - self.logger.info("Finished training UBM!") - self.logger.debug(f"UBM training took {time.time() - begin} seconds") - except Exception as e: - with open(dirty_path, "w"): - pass - if isinstance(e, KaldiProcessingError): - log_kaldi_errors(e.error_logs, self.logger) - e.update_log_file(self.logger.handlers[0].baseFilename) - raise - - def init_training( - self, - identifier: str, - temporary_directory: str, - corpus: Corpus, - dictionary: Dictionary, - previous_trainer: Optional[PretrainedAligner] = None, - ) -> None: - """ - Initialize ivector extractor training - - Parameters - ---------- - identifier: str - Identifier for the training block - temporary_directory: str - Root temporary directory to save - corpus: :class:`~montreal_forced_aligner.corpus.Corpus` - Corpus to use - dictionary: :class:`~montreal_forced_aligner.dictionary.MultispeakerDictionary` - MultispeakerDictionary to use - previous_trainer: :class:`~montreal_forced_aligner.trainers.BaseTrainer`, optional - Previous trainer to initialize from - """ - self._setup_for_init(identifier, temporary_directory, corpus, dictionary, previous_trainer) - done_path = os.path.join(self.train_directory, "done") - if os.path.exists(done_path): - self.logger.info(f"{self.identifier} training already done, skipping initialization.") - return - shutil.copyfile( - previous_trainer.current_model_path, os.path.join(self.train_directory, "final.mdl") - ) - for p in previous_trainer.ali_paths: - shutil.copyfile(p, p.replace(previous_trainer.working_directory, self.train_directory)) - self.corpus.write_utt2spk() - begin = time.time() - self.previous_align_directory = previous_trainer.align_directory - - self.train_ubm() - 
self.init_ivector_train() - self.logger.info("Initialization complete!") - self.logger.debug(f"Initialization took {time.time() - begin} seconds") - - def init_ivector_train(self) -> None: - """ - Initialize ivector extractor training - - Raises - ------ - :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` - If there were any errors in running Kaldi binaries - """ - dirty_path = os.path.join(self.train_directory, "dirty") - try: - init_ie_path = os.path.join(self.train_directory, "0.ie") - if os.path.exists(init_ie_path): - return - self.iteration = 0 - begin = time.time() - # Initialize job_name-vector extractor - log_directory = os.path.join(self.train_directory, "log") - log_path = os.path.join(log_directory, "init.log") - diag_ubm_path = os.path.join(self.train_directory, "final.dubm") - full_ubm_path = os.path.join(self.train_directory, "final.ubm") - with open(log_path, "w") as log_file: - subprocess.call( - [thirdparty_binary("gmm-global-to-fgmm"), diag_ubm_path, full_ubm_path], - stderr=log_file, - ) - subprocess.call( - [ - thirdparty_binary("ivector-extractor-init"), - f"--ivector-dim={self.ivector_dimension}", - "--use-weights=false", - full_ubm_path, - self.current_ie_path, - ], - stderr=log_file, - ) - - # Do Gaussian selection and posterior extraction - gauss_to_post(self) - parse_logs(log_directory) - self.logger.debug(f"Initialization ivectors took {time.time() - begin} seconds") - except Exception as e: - with open(dirty_path, "w"): - pass - if isinstance(e, KaldiProcessingError): - log_kaldi_errors(e.error_logs, self.logger) - e.update_log_file(self.logger.handlers[0].baseFilename) - raise - - def align(self, subset: Optional[int] = None): - """Overwrite align function to export IvectorExtractor to align directory""" - self.save(os.path.join(self.align_directory, "ivector_extractor.zip")) - - def extract_ivectors(self) -> None: - """ - Extract ivectors for the corpus - """ - extract_ivectors(self) - - def training_iteration(self): - 
""" - Run an iteration of training - """ - if os.path.exists(self.next_ie_path): - self.iteration += 1 - return - # Accumulate stats and sum - acc_ivector_stats(self) - - self.iteration += 1 - - def finalize_training(self): - """ - Finalize ivector extractor training - """ - import numpy as np - from joblib import dump - from sklearn.naive_bayes import GaussianNB - - # Rename to final - shutil.copy( - os.path.join(self.train_directory, f"{self.num_iterations}.ie"), - os.path.join(self.train_directory, "final.ie"), - ) - self.training_complete = True - self.iteration = None - extract_ivectors(self) - x = [] - y = [] - speakers = sorted(self.corpus.speakers.keys()) - for j in self.corpus.jobs: - arguments = j.extract_ivector_arguments(self) - for ivector_path in arguments.ivector_paths.values(): - ivec = load_scp(ivector_path) - for utt, ivector in ivec.items(): - ivector = [float(x) for x in ivector] - s = self.corpus.utterances[utt].speaker.name - s_ind = speakers.index(s) - y.append(s_ind) - x.append(ivector) - x = np.array(x) - y = np.array(y) - clf = GaussianNB() - clf.fit(x, y) - clf_param_path = os.path.join(self.train_directory, "speaker_classifier.mdl") - dump(clf, clf_param_path) - classes_path = os.path.join(self.train_directory, "speaker_labels.txt") - with open(classes_path, "w", encoding="utf8") as f: - for i, s in enumerate(speakers): - f.write(f"{s} {i}\n") - - def save(self, path: str, root_directory: Optional[str] = None): - """ - Output IvectorExtractor model - - Parameters - ---------- - path : str - Path to save acoustic model and dictionary - root_directory : str or None - Path for root directory of temporary files - """ - directory, filename = os.path.split(path) - basename, _ = os.path.splitext(filename) - ivector_extractor = IvectorExtractorModel.empty(basename, root_directory) - ivector_extractor.add_meta_file(self) - ivector_extractor.add_model(self.train_directory) - os.makedirs(directory, exist_ok=True) - basename, _ = 
os.path.splitext(path) - ivector_extractor.dump(basename) diff --git a/montreal_forced_aligner/trainers/lda.py b/montreal_forced_aligner/trainers/lda.py deleted file mode 100644 index 7f436d31..00000000 --- a/montreal_forced_aligner/trainers/lda.py +++ /dev/null @@ -1,156 +0,0 @@ -"""Class definitions for LDA trainer""" -from __future__ import annotations - -import os -import time -from typing import TYPE_CHECKING, Optional - -from ..abc import MetaDict, Trainer -from ..exceptions import KaldiProcessingError -from ..multiprocessing import ( - acc_stats, - align, - calc_lda_mllt, - compute_alignment_improvement, - lda_acc_stats, -) -from ..utils import log_kaldi_errors, parse_logs -from .triphone import TriphoneTrainer - -if TYPE_CHECKING: - from ..config import FeatureConfig - from ..corpus import Corpus - from ..dictionary import MultispeakerDictionary - - -__all__ = ["LdaTrainer"] - - -class LdaTrainer(TriphoneTrainer): - """ - - Configuration class for LDA+MLLT training - - Attributes - ---------- - lda_dimension : int - Dimensionality of the LDA matrix - mllt_iterations : list - List of iterations to perform MLLT estimation - random_prune : float - This is approximately the ratio by which we will speed up the - LDA and MLLT calculations via randomized pruning - """ - - def __init__(self, default_feature_config: FeatureConfig): - super(LdaTrainer, self).__init__(default_feature_config) - self.lda_dimension = 40 - self.mllt_iterations = [] - max_mllt_iter = int(self.num_iterations / 2) - 1 - for i in range(1, max_mllt_iter): - if i < max_mllt_iter / 2 and i % 2 == 0: - self.mllt_iterations.append(i) - self.mllt_iterations.append(max_mllt_iter) - if not self.mllt_iterations: - self.mllt_iterations = range(1, 4) - self.random_prune = 4.0 - - self.feature_config.lda = True - self.feature_config.deltas = True - self.uses_splices = True - - def compute_calculated_properties(self) -> None: - """Generate realignment iterations, MLLT estimation iterations, and initial 
gaussians based on configuration""" - super(LdaTrainer, self).compute_calculated_properties() - self.mllt_iterations = [] - max_mllt_iter = int(self.num_iterations / 2) - 1 - for i in range(1, max_mllt_iter): - if i < max_mllt_iter / 2 and i % 2 == 0: - self.mllt_iterations.append(i) - self.mllt_iterations.append(max_mllt_iter) - - @property - def train_type(self) -> str: - """Training identifier""" - return "lda" - - @property - def lda_options(self) -> MetaDict: - """Options for computing LDA""" - return { - "lda_dimension": self.lda_dimension, - "boost_silence": self.boost_silence, - "random_prune": self.random_prune, - "silence_csl": self.dictionary.config.silence_csl, - } - - def init_training( - self, - identifier: str, - temporary_directory: str, - corpus: Corpus, - dictionary: MultispeakerDictionary, - previous_trainer: Optional[Trainer], - ): - """ - Initialize LDA training - - Parameters - ---------- - identifier: str - Identifier for the training block - temporary_directory: str - Root temporary directory to save - corpus: :class:`~montreal_forced_aligner.corpus.Corpus` - Corpus to use - dictionary: Dictionary - Pronunciation dictionary to use - previous_trainer: Trainer, optional - Previous trainer to initialize from - - Raises - ------ - :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` - If there were any errors in running Kaldi binaries - """ - self._setup_for_init(identifier, temporary_directory, corpus, dictionary, previous_trainer) - done_path = os.path.join(self.train_directory, "done") - dirty_path = os.path.join(self.train_directory, "dirty") - if os.path.exists(done_path): - self.logger.info("{self.identifier} training already done, skipping initialization.") - return - begin = time.time() - try: - self.feature_config.directory = None - lda_acc_stats(self) - self.feature_config.directory = self.train_directory - except Exception as e: - with open(dirty_path, "w") as _: - pass - if isinstance(e, KaldiProcessingError): - 
log_kaldi_errors(e.error_logs, self.logger) - e.update_log_file(self.logger.handlers[0].baseFilename) - raise - self._setup_tree() - self.iteration = 1 - self.logger.info("Initialization complete!") - self.logger.debug(f"Initialization took {time.time() - begin} seconds") - - def training_iteration(self): - """ - Run a single training iteration - """ - if os.path.exists(self.next_model_path): - return - if self.iteration in self.realignment_iterations: - align(self) - if self.debug: - compute_alignment_improvement(self) - if self.iteration in self.mllt_iterations: - calc_lda_mllt(self) - - acc_stats(self) - parse_logs(self.log_directory) - if self.iteration < self.final_gaussian_iteration: - self.increment_gaussians() - self.iteration += 1 diff --git a/montreal_forced_aligner/trainers/monophone.py b/montreal_forced_aligner/trainers/monophone.py deleted file mode 100644 index dc35d686..00000000 --- a/montreal_forced_aligner/trainers/monophone.py +++ /dev/null @@ -1,160 +0,0 @@ -"""Class definitions for Monophone trainer""" -from __future__ import annotations - -import os -import re -import subprocess -import time -from typing import TYPE_CHECKING, Optional - -from ..exceptions import KaldiProcessingError -from ..multiprocessing import compile_train_graphs, mono_align_equal -from ..utils import log_kaldi_errors, parse_logs, thirdparty_binary -from .base import BaseTrainer - -if TYPE_CHECKING: - from ..config import FeatureConfig - from ..corpus import Corpus - from ..dictionary import MultispeakerDictionary - - -__all__ = ["MonophoneTrainer"] - - -class MonophoneTrainer(BaseTrainer): - """ - Configuration class for monophone training - - - Attributes - ---------- - initial_gaussians : int - Number of gaussians to begin training - """ - - def __init__(self, default_feature_config: FeatureConfig): - super(MonophoneTrainer, self).__init__(default_feature_config) - self.initial_gaussians = 135 - self.compute_calculated_properties() - - def 
compute_calculated_properties(self) -> None: - """Generate realignment iterations and initial gaussians based on configuration""" - for i in range(1, self.num_iterations): - if i <= int(self.num_iterations / 4): - self.realignment_iterations.append(i) - elif i <= int(self.num_iterations * 2 / 4): - if i - self.realignment_iterations[-1] > 1: - self.realignment_iterations.append(i) - else: - if i - self.realignment_iterations[-1] > 2: - self.realignment_iterations.append(i) - - @property - def train_type(self) -> str: - """Training identifier""" - return "mono" - - @property - def phone_type(self) -> str: - """Phone type""" - return "monophone" - - def init_training( - self, - identifier: str, - temporary_directory: str, - corpus: Corpus, - dictionary: MultispeakerDictionary, - previous_trainer: Optional[BaseTrainer] = None, - ) -> None: - """ - Initialize monophone training - - Parameters - ---------- - identifier: str - Identifier for the training block - temporary_directory: str - Root temporary directory to save - corpus: :class:`~montreal_forced_aligner.corpus.Corpus` - Corpus to use - dictionary: :class:`~montreal_forced_aligner.dictionary.MultispeakerDictionary` - MultispeakerDictionary to use - previous_trainer: Trainer, optional - Previous trainer to initialize from - - Raises - ------ - :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` - If there were any errors in running Kaldi binaries - """ - self._setup_for_init(identifier, temporary_directory, corpus, dictionary, previous_trainer) - done_path = os.path.join(self.train_directory, "done") - dirty_path = os.path.join(self.train_directory, "dirty") - if os.path.exists(done_path): - self.logger.info(f"{self.identifier} training already done, skipping initialization.") - return - begin = time.time() - self.iteration = 0 - tree_path = os.path.join(self.train_directory, "tree") - - try: - feat_dim = corpus.get_feat_dim() - - feature_string = corpus.jobs[0].construct_base_feature_string(corpus) 
- shared_phones_path = os.path.join( - dictionary.get_dictionary("default").phones_dir, "sets.int" - ) - init_log_path = os.path.join(self.log_directory, "init.log") - temp_feats_path = os.path.join(self.train_directory, "temp_feats") - with open(init_log_path, "w") as log_file: - subprocess.call( - [ - thirdparty_binary("subset-feats"), - "--n=10", - feature_string, - f"ark:{temp_feats_path}", - ], - stderr=log_file, - ) - subprocess.call( - [ - thirdparty_binary("gmm-init-mono"), - f"--shared-phones={shared_phones_path}", - f"--train-feats=ark:{temp_feats_path}", - os.path.join( - dictionary.get_dictionary("default").output_directory, "topo" - ), - str(feat_dim), - self.current_model_path, - tree_path, - ], - stderr=log_file, - ) - proc = subprocess.Popen( - [thirdparty_binary("gmm-info"), "--print-args=false", self.current_model_path], - stderr=log_file, - stdout=subprocess.PIPE, - ) - stdout, stderr = proc.communicate() - num = stdout.decode("utf8") - matches = re.search(r"gaussians (\d+)", num) - num_gauss = int(matches.groups()[0]) - if os.path.exists(self.current_model_path): - os.remove(init_log_path) - os.remove(temp_feats_path) - self.initial_gaussians = num_gauss - self.current_gaussians = num_gauss - compile_train_graphs(self) - mono_align_equal(self) - self.iteration = 1 - parse_logs(self.log_directory) - except Exception as e: - with open(dirty_path, "w"): - pass - if isinstance(e, KaldiProcessingError): - log_kaldi_errors(e.error_logs, self.logger) - e.update_log_file(self.logger.handlers[0].baseFilename) - raise - self.logger.info("Initialization complete!") - self.logger.debug(f"Initialization took {time.time() - begin} seconds") diff --git a/montreal_forced_aligner/trainers/sat.py b/montreal_forced_aligner/trainers/sat.py deleted file mode 100644 index 8da160e9..00000000 --- a/montreal_forced_aligner/trainers/sat.py +++ /dev/null @@ -1,370 +0,0 @@ -"""Class definitions for Speaker Adapted Triphone trainer""" -from __future__ import annotations - 
-import os -import shutil -import subprocess -import time -from typing import TYPE_CHECKING, Optional - -from ..abc import MetaDict -from ..exceptions import KaldiProcessingError -from ..multiprocessing import ( - acc_stats, - align, - calc_fmllr, - compile_information, - compile_train_graphs, - compute_alignment_improvement, - convert_alignments, - create_align_model, - tree_stats, -) -from ..utils import log_kaldi_errors, parse_logs, thirdparty_binary -from .triphone import TriphoneTrainer - -if TYPE_CHECKING: - from ..abc import Dictionary, Trainer - from ..config import FeatureConfig - from ..corpus import Corpus - - -__all__ = ["SatTrainer"] - - -class SatTrainer(TriphoneTrainer): - """ - - Configuration class for speaker adapted training (SAT) - - Attributes - ---------- - fmllr_update_type : str - Type of fMLLR estimation, defaults to ``'full'`` - fmllr_iterations : list - List of iterations to perform fMLLR estimation - silence_weight : float - Weight on silence in fMLLR estimation - """ - - def __init__(self, default_feature_config: FeatureConfig): - super(SatTrainer, self).__init__(default_feature_config) - self.fmllr_update_type = "full" - self.fmllr_iterations = [] - max_fmllr_iter = int(self.num_iterations / 2) - 1 - for i in range(1, max_fmllr_iter): - if i < max_fmllr_iter / 2 and i % 2 == 0: - self.fmllr_iterations.append(i) - self.fmllr_iterations.append(max_fmllr_iter) - self.silence_weight = 0.0 - self.feature_config.fmllr = True - self.initial_fmllr = True - self.ensure_train = True - - def compute_calculated_properties(self) -> None: - """Generate realignment iterations, initial gaussians, and fMLLR iteraction based on configuration""" - super(SatTrainer, self).compute_calculated_properties() - self.fmllr_iterations = [] - max_fmllr_iter = int(self.num_iterations / 2) - 1 - for i in range(1, max_fmllr_iter): - if i < max_fmllr_iter / 2 and i % 2 == 0: - self.fmllr_iterations.append(i) - self.fmllr_iterations.append(max_fmllr_iter) - - @property 
- def train_type(self) -> str: - """Training identifier""" - return "sat" - - @property - def fmllr_options(self) -> MetaDict: - """Options for calculating fMLLR transforms""" - return { - "fmllr_update_type": self.fmllr_update_type, - "debug": self.debug, - "initial": self.initial_fmllr, - "silence_csl": self.dictionary.config.silence_csl, - } - - @property - def working_directory(self) -> str: - """Current working directory""" - if self.ensure_train: - return self.train_directory - return super().working_directory - - @property - def working_log_directory(self) -> str: - """Current log directory""" - if self.ensure_train: - return self.log_directory - return super().working_log_directory - - def finalize_training(self) -> None: - """ - Finalize training and create a speaker independent model for initial alignment - - Raises - ------ - :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` - If there were any errors in running Kaldi binaries - """ - try: - super().finalize_training() - create_align_model(self) - self.ensure_train = False - shutil.copyfile( - os.path.join(self.train_directory, "final.alimdl"), - os.path.join(self.align_directory, "final.alimdl"), - ) - except Exception as e: - if isinstance(e, KaldiProcessingError): - log_kaldi_errors(e.error_logs, self.logger) - e.update_log_file(self.logger.handlers[0].baseFilename) - raise - - def training_iteration(self) -> None: - """ - Run a single training iteration - """ - if os.path.exists(self.next_model_path): - self.iteration += 1 - return - if self.iteration in self.realignment_iterations: - align(self) - if self.debug: - compute_alignment_improvement(self) - if self.iteration in self.fmllr_iterations: - calc_fmllr(self) - - acc_stats(self) - parse_logs(self.log_directory) - if self.iteration < self.final_gaussian_iteration: - self.increment_gaussians() - self.iteration += 1 - - def align(self, subset: Optional[int] = None) -> None: - """ - Align a given subset of the corpus - - Parameters - 
---------- - subset: int, optional - Number of utterances to select for the aligned subset - - Raises - ------ - KaldiProcessingError - If there were any errors in running Kaldi binaries - """ - if not os.path.exists(self.align_directory): - self.finalize_training() - dirty_path = os.path.join(self.align_directory, "dirty") - if os.path.exists(dirty_path): # if there was an error, let's redo from scratch - shutil.rmtree(self.align_directory) - done_path = os.path.join(self.align_directory, "done") - if not os.path.exists(done_path): - message = f"Generating alignments using {self.identifier} models" - if subset: - message += f" using {subset} utterances..." - else: - message += " for the whole corpus..." - self.logger.info(message) - begin = time.time() - if subset is None: - self.data_directory = self.corpus.split_directory - else: - self.data_directory = self.corpus.subset_directory(subset) - try: - self.speaker_independent = True - self.initial_fmllr = True - compile_train_graphs(self) - align(self) - - unaligned, average_log_like = compile_information(self) - self.logger.debug( - f"Before SAT, average per frame likelihood (this might not actually mean anything): {average_log_like}" - ) - - if self.speaker_independent: - calc_fmllr(self) - self.speaker_independent = False - self.initial_fmllr = False - align(self) - self.save(os.path.join(self.align_directory, "acoustic_model.zip")) - - unaligned, average_log_like = compile_information(self) - self.logger.debug( - f"Following SAT, average per frame likelihood (this might not actually mean anything): {average_log_like}" - ) - except Exception as e: - with open(dirty_path, "w"): - pass - if isinstance(e, KaldiProcessingError): - log_kaldi_errors(e.error_logs, self.logger) - e.update_log_file(self.logger.handlers[0].baseFilename) - raise - with open(done_path, "w"): - pass - self.logger.debug(f"Alignment took {time.time() - begin} seconds") - else: - self.logger.info(f"Alignments using {self.identifier} models 
already done") - - def init_training( - self, - identifier: str, - temporary_directory: str, - corpus: Corpus, - dictionary: Dictionary, - previous_trainer: Optional[Trainer], - ) -> None: - """ - Initialize speaker-adapted triphone training - - Parameters - ---------- - identifier: str - Identifier for the training block - temporary_directory: str - Root temporary directory to save - corpus: :class:`~montreal_forced_aligner.corpus.Corpus` - Corpus to use - dictionary: :class:`~montreal_forced_aligner.dictionary.MultispeakerDictionary` - MultispeakerDictionary to use - previous_trainer: Trainer, optional - Previous trainer to initialize from - """ - self.feature_config.fmllr = False - self._setup_for_init(identifier, temporary_directory, corpus, dictionary, previous_trainer) - done_path = os.path.join(self.train_directory, "done") - dirty_path = os.path.join(self.train_directory, "dirty") - self.feature_config.fmllr = True - if os.path.exists(done_path): - self.logger.info(f"{self.identifier} training already done, skipping initialization.") - return - if os.path.exists(os.path.join(self.train_directory, "1.mdl")): - return - begin = time.time() - self.logger.info("Initializing speaker-adapted triphone training...") - align_directory = previous_trainer.align_directory - try: - if os.path.exists(os.path.join(align_directory, "lda.mat")): - shutil.copyfile( - os.path.join(align_directory, "lda.mat"), - os.path.join(self.train_directory, "lda.mat"), - ) - tree_stats(self) - log_path = os.path.join(self.log_directory, "questions.log") - tree_path = os.path.join(self.train_directory, "tree") - treeacc_path = os.path.join(self.train_directory, "treeacc") - sets_int_path = os.path.join(self.dictionary.phones_dir, "sets.int") - roots_int_path = os.path.join(self.dictionary.phones_dir, "roots.int") - extra_question_int_path = os.path.join( - self.dictionary.phones_dir, "extra_questions.int" - ) - topo_path = self.dictionary.topo_path - questions_path = 
os.path.join(self.train_directory, "questions.int") - questions_qst_path = os.path.join(self.train_directory, "questions.qst") - with open(log_path, "w") as log_file: - subprocess.call( - [ - thirdparty_binary("cluster-phones"), - treeacc_path, - sets_int_path, - questions_path, - ], - stderr=log_file, - ) - - with open(extra_question_int_path, "r") as in_file, open( - questions_path, "a" - ) as out_file: - for line in in_file: - out_file.write(line) - - log_path = os.path.join(self.log_directory, "compile_questions.log") - with open(log_path, "w") as log_file: - subprocess.call( - [ - thirdparty_binary("compile-questions"), - topo_path, - questions_path, - questions_qst_path, - ], - stderr=log_file, - ) - - log_path = os.path.join(self.log_directory, "build_tree.log") - with open(log_path, "w") as log_file: - subprocess.call( - [ - thirdparty_binary("build-tree"), - "--verbose=1", - f"--max-leaves={self.initial_gaussians}", - f"--cluster-thresh={self.cluster_threshold}", - treeacc_path, - roots_int_path, - questions_qst_path, - topo_path, - tree_path, - ], - stderr=log_file, - ) - - log_path = os.path.join(self.log_directory, "init_model.log") - occs_path = os.path.join(self.train_directory, "0.occs") - mdl_path = os.path.join(self.train_directory, "0.mdl") - with open(log_path, "w") as log_file: - subprocess.call( - [ - thirdparty_binary("gmm-init-model"), - f"--write-occs={occs_path}", - tree_path, - treeacc_path, - topo_path, - mdl_path, - ], - stderr=log_file, - ) - - log_path = os.path.join(self.log_directory, "mixup.log") - with open(log_path, "w") as log_file: - subprocess.call( - [ - thirdparty_binary("gmm-mixup"), - f"--mix-up={self.initial_gaussians}", - mdl_path, - occs_path, - mdl_path, - ], - stderr=log_file, - ) - os.remove(treeacc_path) - - compile_train_graphs(self) - - convert_alignments(self) - - if os.path.exists(os.path.join(align_directory, "trans.0.ark")): - for j in self.corpus.jobs: - for path in j.construct_path_dictionary( - 
align_directory, "trans", "ark" - ).values(): - shutil.copy(path, path.replace(align_directory, self.train_directory)) - else: - - calc_fmllr(self) - self.initial_fmllr = False - self.iteration = 1 - os.rename(occs_path, os.path.join(self.train_directory, "1.occs")) - os.rename(mdl_path, os.path.join(self.train_directory, "1.mdl")) - parse_logs(self.log_directory) - except Exception as e: - with open(dirty_path, "w"): - pass - if isinstance(e, KaldiProcessingError): - log_kaldi_errors(e.error_logs, self.logger) - e.update_log_file(self.logger.handlers[0].baseFilename) - raise - self.logger.info("Initialization complete!") - self.logger.debug(f"Initialization took {time.time() - begin} seconds") diff --git a/montreal_forced_aligner/trainers/triphone.py b/montreal_forced_aligner/trainers/triphone.py deleted file mode 100644 index 7c7efff5..00000000 --- a/montreal_forced_aligner/trainers/triphone.py +++ /dev/null @@ -1,213 +0,0 @@ -"""Class definitions for TriphoneTrainer""" -from __future__ import annotations - -import os -import subprocess -import time -from typing import TYPE_CHECKING, Optional - -from ..exceptions import KaldiProcessingError -from ..multiprocessing import compile_train_graphs, convert_alignments, tree_stats -from ..utils import log_kaldi_errors, parse_logs, thirdparty_binary -from .base import BaseTrainer - -if TYPE_CHECKING: - from ..abc import Dictionary, Trainer - from ..config import FeatureConfig - from ..corpus import Corpus - - -__all__ = ["TriphoneTrainer"] - - -class TriphoneTrainer(BaseTrainer): - """ - Configuration class for triphone training - - Attributes - ---------- - num_iterations : int - Number of training iterations to perform, defaults to 40 - num_leaves : int - Number of states in the decision tree, defaults to 1000 - max_gaussians : int - Number of gaussians in the decision tree, defaults to 10000 - cluster_threshold : int - For build-tree control final bottom-up clustering of leaves, defaults to 100 - """ - - def 
__init__(self, default_feature_config: FeatureConfig): - super(TriphoneTrainer, self).__init__(default_feature_config) - - self.num_iterations = 35 - self.num_leaves = 1000 - self.max_gaussians = 10000 - self.cluster_threshold = -1 - self.compute_calculated_properties() - - def compute_calculated_properties(self) -> None: - """Generate realignment iterations and initial gaussians based on configuration""" - for i in range(0, self.num_iterations, 10): - if i == 0: - continue - self.realignment_iterations.append(i) - self.initial_gaussians = self.num_leaves - self.current_gaussians = self.num_leaves - - @property - def train_type(self) -> str: - """Training identifier""" - return "tri" - - @property - def phone_type(self) -> str: - """Phone type""" - return "triphone" - - def _setup_tree(self) -> None: - """ - Set up the tree for the triphone model - - Raises - ------ - :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` - If there were any errors in running Kaldi binaries - """ - dirty_path = os.path.join(self.train_directory, "dirty") - try: - - tree_stats(self) - log_path = os.path.join(self.log_directory, "questions.log") - tree_path = os.path.join(self.train_directory, "tree") - treeacc_path = os.path.join(self.train_directory, "treeacc") - sets_int_path = os.path.join(self.dictionary.phones_dir, "sets.int") - roots_int_path = os.path.join(self.dictionary.phones_dir, "roots.int") - extra_question_int_path = os.path.join( - self.dictionary.phones_dir, "extra_questions.int" - ) - topo_path = self.dictionary.topo_path - questions_path = os.path.join(self.train_directory, "questions.int") - questions_qst_path = os.path.join(self.train_directory, "questions.qst") - with open(log_path, "w") as log_file: - subprocess.call( - [ - thirdparty_binary("cluster-phones"), - treeacc_path, - sets_int_path, - questions_path, - ], - stderr=log_file, - ) - - with open(extra_question_int_path, "r") as inf, open(questions_path, "a") as outf: - for line in inf: - 
outf.write(line) - - log_path = os.path.join(self.log_directory, "compile_questions.log") - with open(log_path, "w") as log_file: - subprocess.call( - [ - thirdparty_binary("compile-questions"), - topo_path, - questions_path, - questions_qst_path, - ], - stderr=log_file, - ) - - log_path = os.path.join(self.log_directory, "build_tree.log") - with open(log_path, "w") as log_file: - subprocess.call( - [ - thirdparty_binary("build-tree"), - "--verbose=1", - f"--max-leaves={self.initial_gaussians}", - f"--cluster-thresh={self.cluster_threshold}", - treeacc_path, - roots_int_path, - questions_qst_path, - topo_path, - tree_path, - ], - stderr=log_file, - ) - - log_path = os.path.join(self.log_directory, "init_model.log") - occs_path = os.path.join(self.train_directory, "0.occs") - mdl_path = self.current_model_path - with open(log_path, "w") as log_file: - subprocess.call( - [ - thirdparty_binary("gmm-init-model"), - f"--write-occs={occs_path}", - tree_path, - treeacc_path, - topo_path, - mdl_path, - ], - stderr=log_file, - ) - - log_path = os.path.join(self.log_directory, "mixup.log") - with open(log_path, "w") as log_file: - subprocess.call( - [ - thirdparty_binary("gmm-mixup"), - f"--mix-up={self.initial_gaussians}", - mdl_path, - occs_path, - mdl_path, - ], - stderr=log_file, - ) - os.remove(treeacc_path) - parse_logs(self.log_directory) - - compile_train_graphs(self) - - convert_alignments(self) - os.rename(occs_path, self.next_occs_path) - os.rename(mdl_path, self.next_model_path) - except Exception as e: - with open(dirty_path, "w"): - pass - if isinstance(e, KaldiProcessingError): - log_kaldi_errors(e.error_logs, self.logger) - e.update_log_file(self.logger.handlers[0].baseFilename) - raise - - def init_training( - self, - identifier: str, - temporary_directory: str, - corpus: Corpus, - dictionary: Dictionary, - previous_trainer: Optional[Trainer], - ): - """ - Initialize triphone training - - Parameters - ---------- - identifier: str - Identifier for the 
training block - temporary_directory: str - Root temporary directory to save - corpus: :class:`~montreal_forced_aligner.corpus.Corpus` - Corpus to use - dictionary: MultispeakerDictionary - Dictionary to use - previous_trainer: Trainer, optional - Previous trainer to initialize from - """ - self._setup_for_init(identifier, temporary_directory, corpus, dictionary, previous_trainer) - done_path = os.path.join(self.train_directory, "done") - if os.path.exists(done_path): - self.logger.info(f"{self.identifier} training already done, skipping initialization.") - return - begin = time.time() - self._setup_tree() - - self.iteration = 1 - self.logger.info("Initialization complete!") - self.logger.debug(f"Initialization took {time.time() - begin} seconds") diff --git a/montreal_forced_aligner/transcriber.py b/montreal_forced_aligner/transcriber.py deleted file mode 100644 index b7a8a289..00000000 --- a/montreal_forced_aligner/transcriber.py +++ /dev/null @@ -1,401 +0,0 @@ -""" -Transcription -============= - -""" -from __future__ import annotations - -import multiprocessing as mp -import os -import shutil -import subprocess -from typing import TYPE_CHECKING, Optional, Tuple - -from .abc import Transcriber as ABCTranscriber -from .config import TEMP_DIR -from .exceptions import KaldiProcessingError -from .helper import score -from .multiprocessing.transcription import ( - create_hclgs, - score_transcriptions, - transcribe, - transcribe_fmllr, -) -from .utils import log_kaldi_errors, thirdparty_binary - -if TYPE_CHECKING: - from logging import Logger - - from .config.transcribe_config import TranscribeConfig - from .corpus import Corpus - from .dictionary import MultispeakerDictionary - from .models import AcousticModel, LanguageModel - -__all__ = ["Transcriber"] - - -class Transcriber(ABCTranscriber): - """ - Class for performing transcription. 
- - Parameters - ---------- - corpus : :class:`~montreal_forced_aligner.corpus.Corpus` - Corpus to transcribe - dictionary: :class:`~montreal_forced_aligner.dictionary.MultispeakerDictionary` - Pronunciation dictionary to use as a lexicon - acoustic_model : :class:`~montreal_forced_aligner.models.AcousticModel` - Acoustic model to use - language_model : :class:`~montreal_forced_aligner.models.LanguageModel` - Language model to use - transcribe_config : :class:`~montreal_forced_aligner.config.TranscribeConfig` - Language model to use - temp_directory : str, optional - Specifies the temporary directory root to save files need for Kaldi. - If not specified, it will be set to ``~/Documents/MFA`` - debug : bool - Flag for running in debug mode, defaults to false - verbose : bool - Flag for running in verbose mode, defaults to false - evaluation_mode : bool - Flag for running in evaluation mode, defaults to false - logger : :class:`~logging.Logger`, optional - Logger to use - """ - - min_language_model_weight = 7 - max_language_model_weight = 17 - word_insertion_penalties = [0, 0.5, 1.0] - - def __init__( - self, - corpus: Corpus, - dictionary: MultispeakerDictionary, - acoustic_model: AcousticModel, - language_model: LanguageModel, - transcribe_config: TranscribeConfig, - temp_directory: Optional[str] = None, - debug: bool = False, - verbose: bool = False, - evaluation_mode: bool = False, - logger: Optional[Logger] = None, - ): - self.logger = logger - self.corpus = corpus - self.dictionary = dictionary - self.acoustic_model = acoustic_model - self.language_model = language_model - self.transcribe_config = transcribe_config - - if not temp_directory: - temp_directory = TEMP_DIR - self.temp_directory = temp_directory - self.verbose = verbose - self.debug = debug - self.evaluation_mode = evaluation_mode - self.acoustic_model.export_model(self.model_directory) - self.acoustic_model.export_model(self.working_directory) - self.log_dir = 
os.path.join(self.transcribe_directory, "log") - self.uses_voiced = False - self.uses_splices = False - self.uses_cmvn = True - self.speaker_independent = True - os.makedirs(self.log_dir, exist_ok=True) - self.setup() - - @property - def transcribe_directory(self) -> str: - """Temporary directory root for all transcription""" - return os.path.join(self.temp_directory, "transcribe") - - @property - def evaluation_directory(self): - """Evaluation directory path for the current language model weight and word insertion penalty""" - eval_string = f"eval_{self.transcribe_config.language_model_weight}_{self.transcribe_config.word_insertion_penalty}" - path = os.path.join(self.working_directory, eval_string) - os.makedirs(path, exist_ok=True) - return path - - @property - def working_directory(self) -> str: - """Current working directory""" - return self.transcribe_directory - - @property - def evaluation_log_directory(self) -> str: - """Log directory for the current evaluation""" - return os.path.join(self.evaluation_directory, "log") - - @property - def working_log_directory(self) -> str: - """Log directory for the current state""" - return os.path.join(self.working_directory, "log") - - @property - def data_directory(self) -> str: - """Corpus data directory""" - return self.corpus.split_directory - - @property - def model_directory(self) -> str: - """Model directory for the transcriber""" - return os.path.join(self.temp_directory, "models") - - @property - def model_path(self) -> str: - """Acoustic model file path""" - return os.path.join(self.working_directory, "final.mdl") - - @property - def alignment_model_path(self) -> str: - """Alignment (speaker-independent) acoustic model file path""" - path = os.path.join(self.working_directory, "final.alimdl") - if os.path.exists(path): - return path - return self.model_path - - @property - def fmllr_options(self): - """Options for computing fMLLR transforms""" - data = self.transcribe_config.fmllr_options - data["sil_phones"] 
= self.dictionary.config.silence_csl - return data - - @property - def hclg_options(self): - """Options for constructing HCLG FSTs""" - context_width, central_pos = self.get_tree_info() - return { - "context_width": context_width, - "central_pos": central_pos, - "self_loop_scale": self.transcribe_config.self_loop_scale, - "transition_scale": self.transcribe_config.transition_scale, - } - - def get_tree_info(self) -> Tuple[int, int]: - """ - Get the context width and central position for the acoustic model - - Returns - ------- - int - Context width - int - Central position - """ - tree_proc = subprocess.Popen( - [thirdparty_binary("tree-info"), os.path.join(self.model_directory, "tree")], - text=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - stdout, _ = tree_proc.communicate() - context_width = 1 - central_pos = 0 - for line in stdout.split("\n"): - text = line.strip().split(" ") - if text[0] == "context-width": - context_width = int(text[1]) - elif text[0] == "central-position": - central_pos = int(text[1]) - return context_width, central_pos - - def setup(self) -> None: - """ - Sets up the corpus and transcriber - - Raises - ------ - :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` - If there were any errors in running Kaldi binaries - """ - dirty_path = os.path.join(self.model_directory, "dirty") - - if os.path.exists(dirty_path): # if there was an error, let's redo from scratch - shutil.rmtree(self.model_directory) - log_dir = os.path.join(self.model_directory, "log") - os.makedirs(log_dir, exist_ok=True) - self.dictionary.write(write_disambiguation=True) - for dict_name, output_directory in self.dictionary.output_paths.items(): - words_path = os.path.join(self.model_directory, f"words.{dict_name}.txt") - shutil.copyfile(os.path.join(output_directory, "words.txt"), words_path) - self.corpus.initialize_corpus(self.dictionary, self.acoustic_model.feature_config) - - big_arpa_path = self.language_model.carpa_path - small_arpa_path 
= self.language_model.small_arpa_path - medium_arpa_path = self.language_model.medium_arpa_path - if not os.path.exists(small_arpa_path) or not os.path.exists(medium_arpa_path): - self.logger.info("Parsing large ngram model...") - mod_path = os.path.join(self.model_directory, "base_lm.mod") - new_carpa_path = os.path.join(self.model_directory, "base_lm.arpa") - with open(big_arpa_path, "r", encoding="utf8") as inf, open( - new_carpa_path, "w", encoding="utf8" - ) as outf: - for line in inf: - outf.write(line.lower()) - big_arpa_path = new_carpa_path - subprocess.call(["ngramread", "--ARPA", big_arpa_path, mod_path]) - - if not os.path.exists(small_arpa_path): - self.logger.info( - "Generating small model from the large ARPA with a pruning threshold of 3e-7" - ) - prune_thresh_small = 0.0000003 - small_mod_path = mod_path.replace(".mod", "_small.mod") - subprocess.call( - [ - "ngramshrink", - "--method=relative_entropy", - f"--theta={prune_thresh_small}", - mod_path, - small_mod_path, - ] - ) - subprocess.call(["ngramprint", "--ARPA", small_mod_path, small_arpa_path]) - - if not os.path.exists(medium_arpa_path): - self.logger.info( - "Generating medium model from the large ARPA with a pruning threshold of 1e-7" - ) - prune_thresh_medium = 0.0000001 - med_mod_path = mod_path.replace(".mod", "_med.mod") - subprocess.call( - [ - "ngramshrink", - "--method=relative_entropy", - f"--theta={prune_thresh_medium}", - mod_path, - med_mod_path, - ] - ) - subprocess.call(["ngramprint", "--ARPA", med_mod_path, medium_arpa_path]) - try: - create_hclgs(self) - except Exception as e: - with open(dirty_path, "w"): - pass - if isinstance(e, KaldiProcessingError): - log_kaldi_errors(e.error_logs, self.logger) - e.update_log_file(self.logger.handlers[0].baseFilename) - raise - - def transcribe(self) -> None: - """ - Transcribe the corpus - - Raises - ------ - :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` - If there were any errors in running Kaldi binaries - """ - 
self.logger.info("Beginning transcription...") - dirty_path = os.path.join(self.transcribe_directory, "dirty") - if os.path.exists(dirty_path): - shutil.rmtree(self.transcribe_directory, ignore_errors=True) - os.makedirs(self.log_dir, exist_ok=True) - try: - transcribe(self) - if ( - self.acoustic_model.feature_config.fmllr - and not self.transcribe_config.ignore_speakers - and self.transcribe_config.fmllr - ): - self.logger.info("Performing speaker adjusted transcription...") - transcribe_fmllr(self) - score_transcriptions(self) - except Exception as e: - with open(dirty_path, "w"): - pass - if isinstance(e, KaldiProcessingError): - log_kaldi_errors(e.error_logs, self.logger) - e.update_log_file(self.logger.handlers[0].baseFilename) - raise - - def evaluate(self): - """ - Evaluates the transcripts if there are reference transcripts - - Raises - ------ - :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` - If there were any errors in running Kaldi binaries - """ - self.logger.info("Evaluating transcripts...") - self._load_transcripts() - # Sentence-level measures - - correct = 0 - incorrect = 0 - # Word-level measures - total_edits = 0 - total_length = 0 - issues = [] - with mp.Pool(self.corpus.num_jobs) as pool: - to_comp = [] - for utt_name, utterance in self.corpus.utterances.items(): - g = utterance.text.split() - if not utterance.transcription_text: - incorrect += 1 - total_edits += len(g) - total_length += len(g) - - h = utterance.transcription_text.split() - if g != h: - issues.append((utt_name, g, h)) - to_comp.append((g, h)) - gen = pool.starmap(score, to_comp) - for (edits, length) in gen: - if edits == 0: - correct += 1 - else: - incorrect += 1 - total_edits += edits - total_length += length - ser = 100 * incorrect / (correct + incorrect) - wer = 100 * total_edits / total_length - output_path = os.path.join(self.evaluation_directory, "transcription_issues.csv") - with open(output_path, "w", encoding="utf8") as f: - for utt, g, h in issues: 
- g = " ".join(g) - h = " ".join(h) - f.write(f"{utt},{g},{h}\n") - self.logger.info(f"SER: {ser:.2f}%, WER: {wer:.2f}%") - return ser, wer - - def _load_transcripts(self): - """Load transcripts from Kaldi temporary files""" - for j in self.corpus.jobs: - score_arguments = j.score_arguments(self) - for tra_path in score_arguments.tra_paths.values(): - - with open(tra_path, "r", encoding="utf8") as f: - for line in f: - t = line.strip().split(" ") - utt = t[0] - utterance = self.corpus.utterances[utt] - speaker = utterance.speaker - lookup = speaker.dictionary.reversed_word_mapping - ints = t[1:] - if not ints: - continue - transcription = [] - for i in ints: - transcription.append(lookup[int(i)]) - utterance.transcription_text = " ".join(transcription) - - def export_transcriptions(self, output_directory: str) -> None: - """ - Export transcriptions - - Parameters - ---------- - output_directory: str - Directory to save transcriptions - """ - backup_output_directory = None - if not self.transcribe_config.overwrite: - backup_output_directory = os.path.join(self.transcribe_directory, "transcriptions") - os.makedirs(backup_output_directory, exist_ok=True) - self._load_transcripts() - for file in self.corpus.files.values(): - file.save(output_directory, backup_output_directory) diff --git a/montreal_forced_aligner/transcription/__init__.py b/montreal_forced_aligner/transcription/__init__.py new file mode 100644 index 00000000..41002280 --- /dev/null +++ b/montreal_forced_aligner/transcription/__init__.py @@ -0,0 +1,3 @@ +from montreal_forced_aligner.transcription.transcriber import Transcriber + +__all__ = ["Transcriber", "transcriber"] diff --git a/montreal_forced_aligner/multiprocessing/transcription.py b/montreal_forced_aligner/transcription/multiprocessing.py similarity index 72% rename from montreal_forced_aligner/multiprocessing/transcription.py rename to montreal_forced_aligner/transcription/multiprocessing.py index 76754775..05990dd8 100644 --- 
a/montreal_forced_aligner/multiprocessing/transcription.py +++ b/montreal_forced_aligner/transcription/multiprocessing.py @@ -7,21 +7,16 @@ import os import re -import shutil import subprocess import sys -from typing import TYPE_CHECKING, Dict, List, Optional, TextIO, Union +from typing import TYPE_CHECKING, NamedTuple, TextIO from ..abc import MetaDict -from ..exceptions import KaldiProcessingError from ..utils import thirdparty_binary -from .helper import run_mp, run_non_mp if TYPE_CHECKING: - from ..dictionary import MappingType - from ..transcriber import Transcriber + from ..abc import MappingType - DictionaryNames = Optional[Union[List[str], str]] __all__ = [ "compose_g", @@ -29,10 +24,6 @@ "compose_clg", "compose_hclg", "compose_g_carpa", - "create_hclgs", - "transcribe", - "transcribe_fmllr", - "score_transcriptions", "fmllr_rescore_func", "final_fmllr_est_func", "initial_fmllr_func", @@ -45,10 +36,144 @@ ] +class CreateHclgArguments(NamedTuple): + """Arguments for :func:`~montreal_forced_aligner.transcription.multiprocessing.create_hclg_func`""" + + log_path: str + working_directory: str + path_template: str + words_path: str + carpa_path: str + small_arpa_path: str + medium_arpa_path: str + big_arpa_path: str + model_path: str + disambig_L_path: str + disambig_int_path: str + hclg_options: MetaDict + words_mapping: MappingType + + @property + def hclg_path(self) -> str: + return self.path_template.format(file_name="HCLG") + + +class DecodeArguments(NamedTuple): + """Arguments for :func:`~montreal_forced_aligner.transcription.multiprocessing.decode_func`""" + + log_path: str + dictionaries: list[str] + feature_strings: dict[str, str] + decode_options: MetaDict + model_path: str + lat_paths: dict[str, str] + words_paths: dict[str, str] + hclg_paths: dict[str, str] + + +class ScoreArguments(NamedTuple): + """Arguments for :func:`~montreal_forced_aligner.transcription.multiprocessing.score_func`""" + + log_path: str + dictionaries: list[str] + score_options: 
MetaDict + lat_paths: dict[str, str] + rescored_lat_paths: dict[str, str] + carpa_rescored_lat_paths: dict[str, str] + words_paths: dict[str, str] + tra_paths: dict[str, str] + + +class LmRescoreArguments(NamedTuple): + """Arguments for :func:`~montreal_forced_aligner.transcription.multiprocessing.lm_rescore_func`""" + + log_path: str + dictionaries: list[str] + lm_rescore_options: MetaDict + lat_paths: dict[str, str] + rescored_lat_paths: dict[str, str] + old_g_paths: dict[str, str] + new_g_paths: dict[str, str] + + +class CarpaLmRescoreArguments(NamedTuple): + """Arguments for :func:`~montreal_forced_aligner.transcription.multiprocessing.carpa_lm_rescore_func`""" + + log_path: str + dictionaries: list[str] + lat_paths: dict[str, str] + rescored_lat_paths: dict[str, str] + old_g_paths: dict[str, str] + new_g_paths: dict[str, str] + + +class InitialFmllrArguments(NamedTuple): + """Arguments for :func:`~montreal_forced_aligner.transcription.multiprocessing.initial_fmllr_func`""" + + log_path: str + dictionaries: list[str] + feature_strings: dict[str, str] + model_path: str + fmllr_options: MetaDict + pre_trans_paths: dict[str, str] + lat_paths: dict[str, str] + spk2utt_paths: dict[str, str] + + +class LatGenFmllrArguments(NamedTuple): + """Arguments for :func:`~montreal_forced_aligner.transcription.multiprocessing.lat_gen_fmllr_func`""" + + log_path: str + dictionaries: list[str] + feature_strings: dict[str, str] + model_path: str + decode_options: MetaDict + words_paths: dict[str, str] + hclg_paths: dict[str, str] + tmp_lat_paths: dict[str, str] + + +class FinalFmllrArguments(NamedTuple): + """Arguments for :func:`~montreal_forced_aligner.transcription.multiprocessing.final_fmllr_est_func`""" + + log_path: str + dictionaries: list[str] + feature_strings: dict[str, str] + model_path: str + fmllr_options: MetaDict + trans_paths: dict[str, str] + spk2utt_paths: dict[str, str] + tmp_lat_paths: dict[str, str] + + +class FmllrRescoreArguments(NamedTuple): + """Arguments 
for :func:`~montreal_forced_aligner.transcription.multiprocessing.fmllr_rescore_func`""" + + log_path: str + dictionaries: list[str] + feature_strings: dict[str, str] + model_path: str + fmllr_options: MetaDict + tmp_lat_paths: dict[str, str] + final_lat_paths: dict[str, str] + + def compose_lg(dictionary_path: str, small_g_path: str, lg_path: str, log_file: TextIO) -> None: """ Compose an LG.fst + See Also + -------- + :kaldi_src:`fsttablecompose` + Relevant Kaldi binary + :kaldi_src:`fstdeterminizestar` + Relevant Kaldi binary + :kaldi_src:`fstminimizeencoded` + Relevant Kaldi binary + :kaldi_src:`fstpushspecial` + Relevant Kaldi binary + + Parameters ---------- dictionary_path: str @@ -114,6 +239,13 @@ def compose_clg( """ Compose a CLG.fst + See Also + -------- + :kaldi_src:`fstcomposecontext` + Relevant Kaldi binary + :openfst_src:`fstarcsort` + Relevant OpenFst binary + Parameters ---------- in_disambig: str @@ -166,6 +298,23 @@ def compose_hclg( """ Compost HCLG.fst for a dictionary + See Also + -------- + :kaldi_src:`make-h-transducer` + Relevant Kaldi binary + :kaldi_src:`fsttablecompose` + Relevant Kaldi binary + :kaldi_src:`fstdeterminizestar` + Relevant Kaldi binary + :kaldi_src:`fstrmsymbols` + Relevant Kaldi binary + :kaldi_src:`fstrmepslocal` + Relevant Kaldi binary + :kaldi_src:`fstminimizeencoded` + Relevant Kaldi binary + :openfst_src:`fstarcsort` + Relevant OpenFst binary + Parameters ---------- model_directory: str @@ -243,6 +392,11 @@ def compose_g(arpa_path: str, words_path: str, g_path: str, log_file: TextIO) -> """ Create G.fst from an ARPA formatted language model + See Also + -------- + :kaldi_src:`arpa2fst` + Relevant Kaldi binary + Parameters ---------- arpa_path: str @@ -278,13 +432,18 @@ def compose_g_carpa( """ Compose a large ARPA model into a G.carpa file + See Also + -------- + :kaldi_src:`arpa-to-const-arpa` + Relevant Kaldi binary + Parameters ---------- in_carpa_path: str Input ARPA model path temp_carpa_path: str Temporary 
CARPA model path - words_mapping: Dict[str, int] + words_mapping: dict[str, int] Words symbols mapping carpa_path: str Path to save output G.carpa @@ -371,6 +530,17 @@ def create_hclg_func( """ Create HCLG.fst file + See Also + -------- + :meth:`.Transcriber.create_hclgs` + Main function that calls this function in parallel + :meth:`.Transcriber.create_hclgs_arguments` + Job method for generating arguments for this function + :kaldi_src:`add-self-loops` + Relevant Kaldi binary + :openfst_src:`fstconvert` + Relevant OpenFst binary + Parameters ---------- log_path: str @@ -385,19 +555,19 @@ def create_hclg_func( Path to G.carpa file small_arpa_path: str Path to small ARPA file - medium_arpa_path: + medium_arpa_path: str Path to medium ARPA file - big_arpa_path: + big_arpa_path: str Path to big ARPA file model_path: str Path to acoustic model file - disambig_L_path: + disambig_L_path: str Path to L_disambig.fst file - disambig_int_path: + disambig_int_path: str Path to dictionary's disambiguation symbols file - hclg_options: :class:`~montreal_forced_aligner.abc.MetaDict` + hclg_options: dict[str, Any] Configuration options for composing HCLG.fst - words_mapping: Dict[str, int] + words_mapping: dict[str, int] Word labels to integer ID mapping """ hclg_path = path_template.format(file_name="HCLG") @@ -487,66 +657,45 @@ def create_hclg_func( log_file.write(f"There was an error in generating {hclg_path}") -def create_hclgs(transcriber: Transcriber): - """ - Create HCLG.fst files for every dictionary being used by a :class:`~montreal_forced_aligner.transcriber.Transcriber` - - Parameters - ---------- - transcriber: :class:`~montreal_forced_aligner.transcriber.Transcriber` - Transcriber in use - """ - dict_arguments = {} - - for j in transcriber.corpus.jobs: - dict_arguments.update(j.create_hclgs_arguments(transcriber)) - dict_arguments = list(dict_arguments.values()) - if transcriber.transcribe_config.use_mp: - run_mp(create_hclg_func, dict_arguments, 
transcriber.working_log_directory) - else: - run_non_mp(create_hclg_func, dict_arguments, transcriber.working_log_directory) - error_logs = [] - for arg in dict_arguments: - if not os.path.exists(arg.hclg_path): - error_logs.append(arg.log_path) - else: - with open(arg.log_path, "r", encoding="utf8") as f: - for line in f: - transcriber.logger.warning(line) - if error_logs: - raise KaldiProcessingError(error_logs) - - def decode_func( log_path: str, - dictionaries: List[str], - feature_strings: Dict[str, str], + dictionaries: list[str], + feature_strings: dict[str, str], decode_options: MetaDict, model_path: str, - lat_paths: Dict[str, str], - word_symbol_paths: Dict[str, str], - hclg_paths: Dict[str, str], + lat_paths: dict[str, str], + word_symbol_paths: dict[str, str], + hclg_paths: dict[str, str], ) -> None: """ Multiprocessing function for performing decoding + See Also + -------- + :meth:`.Transcriber.transcribe` + Main function that calls this function in parallel + :meth:`.Transcriber.decode_arguments` + Job method for generating arguments for this function + :kaldi_src:`gmm-latgen-faster` + Relevant Kaldi binary + Parameters ---------- log_path: str Path to save log output - dictionaries: List[str] + dictionaries: list[str] List of dictionary names - feature_strings: Dict[str, str] + feature_strings: dict[str, str] Dictionary of feature strings per dictionary name - decode_options: :class:`~montreal_forced_aligner.abc.MetaDict` + decode_options: dict[str, Any] Options for decoding model_path: str Path to acoustic model file - lat_paths: Dict[str, str] + lat_paths: dict[str, str] Dictionary of lattice archive paths per dictionary name - word_symbol_paths: Dict[str, str] + word_symbol_paths: dict[str, str] Dictionary of word symbol paths per dictionary name - hclg_paths: Dict[str, str] + hclg_paths: dict[str, str] Dictionary of HCLG.fst paths per dictionary name """ with open(log_path, "w", encoding="utf8") as log_file: @@ -557,14 +706,16 @@ def decode_func( 
hclg_path = hclg_paths[dict_name] if os.path.exists(lat_path): continue - if decode_options["fmllr"] and decode_options["first_beam"] is not None: + if ( + decode_options["uses_speaker_adaptation"] + and decode_options["first_beam"] is not None + ): beam = decode_options["first_beam"] else: beam = decode_options["beam"] if ( - decode_options["fmllr"] + decode_options["uses_speaker_adaptation"] and decode_options["first_max_active"] is not None - and not decode_options["ignore_speakers"] ): max_active = decode_options["first_max_active"] else: @@ -591,34 +742,47 @@ def decode_func( def score_func( log_path: str, - dictionaries: List[str], + dictionaries: list[str], score_options: MetaDict, - lat_paths: Dict[str, str], - rescored_lat_paths: Dict[str, str], - carpa_rescored_lat_paths: Dict[str, str], - words_paths: Dict[str, str], - tra_paths: Dict[str, str], + lat_paths: dict[str, str], + rescored_lat_paths: dict[str, str], + carpa_rescored_lat_paths: dict[str, str], + words_paths: dict[str, str], + tra_paths: dict[str, str], ) -> None: """ Multiprocessing function for scoring lattices + See Also + -------- + :func:`~montreal_forced_aligner.transcription.Transcriber.score_transcriptions` + Main function that calls this function in parallel + :meth:`.Transcriber.score_arguments` + Job method for generating arguments for this function + :kaldi_src:`lattice-scale` + Relevant Kaldi binary + :kaldi_src:`lattice-add-penalty` + Relevant Kaldi binary + :kaldi_src:`lattice-best-path` + Relevant Kaldi binary + Parameters ---------- log_path: str Path to save log output - dictionaries: List[str] + dictionaries: list[str] List of dictionary names - score_options: :class:`~montreal_forced_aligner.abc.MetaDict` + score_options: dict[str, Any] Options for scoring - lat_paths: Dict[str, str] + lat_paths: dict[str, str] Dictionary of lattice archive paths per dictionary name - rescored_lat_paths: Dict[str, str] + rescored_lat_paths: dict[str, str] Dictionary of medium G.fst rescored 
lattice archive paths per dictionary name - carpa_rescored_lat_paths: Dict[str, str] + carpa_rescored_lat_paths: dict[str, str] Dictionary of carpa-rescored lattice archive paths per dictionary name - words_paths: Dict[str, str] + words_paths: dict[str, str] Dictionary of word symbol paths per dictionary name - tra_paths: Dict[str, str] + tra_paths: dict[str, str] Dictionary of transcription archive paths per dictionary name """ with open(log_path, "w", encoding="utf8") as log_file: @@ -673,31 +837,42 @@ def score_func( def lm_rescore_func( log_path: str, - dictionaries: List[str], + dictionaries: list[str], lm_rescore_options: MetaDict, - lat_paths: Dict[str, str], - rescored_lat_paths: Dict[str, str], - old_g_paths: Dict[str, str], - new_g_paths: Dict[str, str], + lat_paths: dict[str, str], + rescored_lat_paths: dict[str, str], + old_g_paths: dict[str, str], + new_g_paths: dict[str, str], ) -> None: """ Multiprocessing function rescore lattices by replacing the small G.fst with the medium G.fst + See Also + -------- + :func:`~montreal_forced_aligner.transcription.Transcriber.transcribe` + Main function that calls this function in parallel + :meth:`.Transcriber.lm_rescore_arguments` + Job method for generating arguments for this function + :kaldi_src:`lattice-lmrescore-pruned` + Relevant Kaldi binary + :openfst_src:`fstproject` + Relevant OpenFst binary + Parameters ---------- log_path: str Path to save log output - dictionaries: List[str] + dictionaries: list[str] List of dictionary names - lm_rescore_options: :class:`~montreal_forced_aligner.abc.MetaDict` + lm_rescore_options: dict[str, Any] Options for rescoring - lat_paths: Dict[str, str] + lat_paths: dict[str, str] Dictionary of lattice archive paths per dictionary name - rescored_lat_paths: Dict[str, str] + rescored_lat_paths: dict[str, str] Dictionary of rescored lattice archive paths per dictionary name - old_g_paths: Dict[str, str] + old_g_paths: dict[str, str] Dictionary of small G.fst paths per 
dictionary name - new_g_paths: Dict[str, str] + new_g_paths: dict[str, str] Dictionary of medium G.fst paths per dictionary name """ with open(log_path, "w", encoding="utf8") as log_file: @@ -713,15 +888,22 @@ def lm_rescore_func( if os.path.exists(rescored_lat_path): continue + project_proc = subprocess.Popen( + [thirdparty_binary("fstproject"), project_type_arg, old_g_path], + stdout=subprocess.PIPE, + stderr=log_file, + env=os.environ, + ) lattice_scale_proc = subprocess.Popen( [ thirdparty_binary("lattice-lmrescore-pruned"), f"--acoustic-scale={lm_rescore_options['acoustic_scale']}", - f"fstproject {project_type_arg} {old_g_path} |", + "-", f"fstproject {project_type_arg} {new_g_path} |", f"ark:{lat_path}", f"ark:{rescored_lat_path}", ], + stdin=project_proc.stdout, stderr=log_file, env=os.environ, ) @@ -730,29 +912,42 @@ def lm_rescore_func( def carpa_lm_rescore_func( log_path: str, - dictionaries: List[str], - lat_paths: Dict[str, str], - rescored_lat_paths: Dict[str, str], - old_g_paths: Dict[str, str], - new_g_paths: Dict[str, str], + dictionaries: list[str], + lat_paths: dict[str, str], + rescored_lat_paths: dict[str, str], + old_g_paths: dict[str, str], + new_g_paths: dict[str, str], ) -> None: """ Multiprocessing function to rescore lattices by replacing medium G.fst with large G.carpa + See Also + -------- + :func:`~montreal_forced_aligner.transcription.Transcriber.transcribe` + Main function that calls this function in parallel + :meth:`.Transcriber.carpa_lm_rescore_arguments` + Job method for generating arguments for this function + :openfst_src:`fstproject` + Relevant OpenFst binary + :kaldi_src:`lattice-lmrescore` + Relevant Kaldi binary + :kaldi_src:`lattice-lmrescore-const-arpa` + Relevant Kaldi binary + Parameters ---------- log_path: str Path to save log output - dictionaries: List[str] + dictionaries: list[str] List of dictionary names - lat_paths: Dict[str, str] - PronunciationDictionary of lattice archive paths per dictionary name - 
rescored_lat_paths: Dict[str, str] - PronunciationDictionary of rescored lattice archive paths per dictionary name - old_g_paths: Dict[str, str] - PronunciationDictionary of medium G.fst paths per dictionary name - new_g_paths: Dict[str, str] - PronunciationDictionary of large G.carpa paths per dictionary name + lat_paths: dict[str, str] + Dictionary of lattice archive paths per dictionary name + rescored_lat_paths: dict[str, str] + Dictionary of rescored lattice archive paths per dictionary name + old_g_paths: dict[str, str] + Dictionary of medium G.fst paths per dictionary name + new_g_paths: dict[str, str] + Dictionary of large G.carpa paths per dictionary name """ with open(log_path, "w", encoding="utf8") as log_file: for dict_name in dictionaries: @@ -800,114 +995,51 @@ def carpa_lm_rescore_func( lmrescore_const_proc.communicate() -def transcribe(transcriber: Transcriber) -> None: - """ - Transcribe a corpus using a Transcriber - - Parameters - ---------- - transcriber: :class:`~montreal_forced_aligner.transcriber.Transcriber` - Transcriber - """ - config = transcriber.transcribe_config - - jobs = [x.decode_arguments(transcriber) for x in transcriber.corpus.jobs] - - if config.use_mp: - run_mp(decode_func, jobs, transcriber.working_log_directory) - else: - run_non_mp(decode_func, jobs, transcriber.working_log_directory) - - jobs = [x.lm_rescore_arguments(transcriber) for x in transcriber.corpus.jobs] - - if config.use_mp: - run_mp(lm_rescore_func, jobs, transcriber.working_log_directory) - else: - run_non_mp(lm_rescore_func, jobs, transcriber.working_log_directory) - - jobs = [x.carpa_lm_rescore_arguments(transcriber) for x in transcriber.corpus.jobs] - - if config.use_mp: - run_mp(carpa_lm_rescore_func, jobs, transcriber.working_log_directory) - else: - run_non_mp(carpa_lm_rescore_func, jobs, transcriber.working_log_directory) - - -def score_transcriptions(transcriber: Transcriber): - """ - Score transcriptions if reference text is available in the corpus - - 
Parameters - ---------- - transcriber: :class:`~montreal_forced_aligner.transcriber.Transcriber` - Transcriber - """ - if transcriber.evaluation_mode: - best_wer = 10000 - best = None - for lmwt in range( - transcriber.min_language_model_weight, transcriber.max_language_model_weight - ): - for wip in transcriber.word_insertion_penalties: - transcriber.transcribe_config.language_model_weight = lmwt - transcriber.transcribe_config.word_insertion_penalty = wip - os.makedirs(transcriber.evaluation_log_directory, exist_ok=True) - - jobs = [x.score_arguments(transcriber) for x in transcriber.corpus.jobs] - if transcriber.transcribe_config.use_mp: - run_mp(score_func, jobs, transcriber.evaluation_log_directory) - else: - run_non_mp(score_func, jobs, transcriber.evaluation_log_directory) - ser, wer = transcriber.evaluate() - if wer < best_wer: - best = (lmwt, wip) - transcriber.transcribe_config.language_model_weight = best[0] - transcriber.transcribe_config.word_insertion_penalty = best[1] - for j in transcriber.corpus.jobs: - score_args = j.score_arguments(transcriber) - for p in score_args.tra_paths.values(): - shutil.copyfile( - p, - p.replace(transcriber.evaluation_directory, transcriber.transcribe_directory), - ) - else: - jobs = [x.score_arguments(transcriber) for x in transcriber.corpus.jobs] - if transcriber.transcribe_config.use_mp: - run_mp(score_func, jobs, transcriber.working_log_directory) - else: - run_non_mp(score_func, jobs, transcriber.working_log_directory) - - def initial_fmllr_func( log_path: str, - dictionaries: List[str], - feature_strings: Dict[str, str], + dictionaries: list[str], + feature_strings: dict[str, str], model_path: str, fmllr_options: MetaDict, - trans_paths: Dict[str, str], - lat_paths: Dict[str, str], - spk2utt_paths: Dict[str, str], + trans_paths: dict[str, str], + lat_paths: dict[str, str], + spk2utt_paths: dict[str, str], ) -> None: """ Multiprocessing function for running initial fMLLR calculation + See Also + -------- + 
:func:`~montreal_forced_aligner.transcription.Transcriber.transcribe_fmllr` + Main function that calls this function in parallel + :meth:`.Transcriber.initial_fmllr_arguments` + Job method for generating arguments for this function + :kaldi_src:`lattice-to-post` + Relevant Kaldi binary + :kaldi_src:`weight-silence-post` + Relevant Kaldi binary + :kaldi_src:`gmm-post-to-gpost` + Relevant Kaldi binary + :kaldi_src:`gmm-est-fmllr-gpost` + Relevant Kaldi binary + Parameters ---------- log_path: str Path to save log output - dictionaries: List[str] + dictionaries: list[str] List of dictionary names - feature_strings: Dict[str, str] + feature_strings: dict[str, str] Dictionary of feature strings per dictionary name model_path: str Path to acoustic model file - fmllr_options: :class:`~montreal_forced_aligner.abc.MetaDict` + fmllr_options: dict[str, Any] Options for calculating fMLLR transforms - trans_paths: Dict[str, str] + trans_paths: dict[str, str] Dictionary of transform archives per dictionary name - lat_paths: Dict[str, str] + lat_paths: dict[str, str] Dictionary of lattice archive paths per dictionary name - spk2utt_paths: Dict[str, str] + spk2utt_paths: dict[str, str] Dictionary of spk2utt scp files per dictionary name """ with open(log_path, "w", encoding="utf8") as log_file: @@ -975,34 +1107,43 @@ def initial_fmllr_func( def lat_gen_fmllr_func( log_path: str, - dictionaries: List[str], - feature_strings: Dict[str, str], + dictionaries: list[str], + feature_strings: dict[str, str], model_path: str, decode_options: MetaDict, - word_symbol_paths: Dict[str, str], - hclg_paths: Dict[str, str], - tmp_lat_paths: Dict[str, str], + word_symbol_paths: dict[str, str], + hclg_paths: dict[str, str], + tmp_lat_paths: dict[str, str], ) -> None: """ Regenerate lattices using initial fMLLR transforms + See Also + -------- + :func:`~montreal_forced_aligner.transcription.Transcriber.transcribe_fmllr` + Main function that calls this function in parallel + 
:meth:`.Transcriber.lat_gen_fmllr_arguments` + Job method for generating arguments for this function + :kaldi_src:`gmm-latgen-faster` + Relevant Kaldi binary + Parameters ---------- log_path: str Path to save log output - dictionaries: List[str] + dictionaries: list[str] List of dictionary names - feature_strings: Dict[str, str] + feature_strings: dict[str, str] Dictionary of feature strings per dictionary name model_path: str Path to acoustic model file - decode_options: :class:`~montreal_forced_aligner.abc.MetaDict` + decode_options: dict[str, Any] Options for decoding - word_symbol_paths: Dict[str, str] + word_symbol_paths: dict[str, str] Dictionary of word symbol paths per dictionary name - hclg_paths: Dict[str, str] + hclg_paths: dict[str, str] Dictionary of HCLG.fst paths per dictionary name - tmp_lat_paths: Dict[str, str] + tmp_lat_paths: dict[str, str] Dictionary of temporary lattice archive paths per dictionary name """ with open(log_path, "w", encoding="utf8") as log_file: @@ -1035,34 +1176,51 @@ def lat_gen_fmllr_func( def final_fmllr_est_func( log_path: str, - dictionaries: List[str], - feature_strings: Dict[str, str], + dictionaries: list[str], + feature_strings: dict[str, str], model_path: str, fmllr_options: MetaDict, - trans_paths: Dict[str, str], - spk2utt_paths: Dict[str, str], - tmp_lat_paths: Dict[str, str], + trans_paths: dict[str, str], + spk2utt_paths: dict[str, str], + tmp_lat_paths: dict[str, str], ) -> None: """ Multiprocessing function for running final fMLLR estimation + See Also + -------- + :func:`~montreal_forced_aligner.transcription.Transcriber.transcribe_fmllr` + Main function that calls this function in parallel + :meth:`.Transcriber.final_fmllr_arguments` + Job method for generating arguments for this function + :kaldi_src:`lattice-determinize-pruned` + Relevant Kaldi binary + :kaldi_src:`lattice-to-post` + Relevant Kaldi binary + :kaldi_src:`weight-silence-post` + Relevant Kaldi binary + :kaldi_src:`gmm-est-fmllr` + Relevant 
Kaldi binary + :kaldi_src:`compose-transforms` + Relevant Kaldi binary + Parameters ---------- log_path: str Path to save log output - dictionaries: List[str] + dictionaries: list[str] List of dictionary names - feature_strings: Dict[str, str] + feature_strings: dict[str, str] Dictionary of feature strings per dictionary name model_path: str Path to acoustic model file - fmllr_options: :class:`~montreal_forced_aligner.abc.MetaDict` + fmllr_options: dict[str, Any] Options for calculating fMLLR transforms - trans_paths: Dict[str, str] + trans_paths: dict[str, str] Dictionary of transform archives per dictionary name - spk2utt_paths: Dict[str, str] + spk2utt_paths: dict[str, str] Dictionary of spk2utt scp files per dictionary name - tmp_lat_paths: Dict[str, str] + tmp_lat_paths: dict[str, str] Dictionary of temporary lattice archive paths per dictionary name """ with open(log_path, "w", encoding="utf8") as log_file: @@ -1146,31 +1304,42 @@ def final_fmllr_est_func( def fmllr_rescore_func( log_path: str, - dictionaries: List[str], - feature_strings: Dict[str, str], + dictionaries: list[str], + feature_strings: dict[str, str], model_path: str, fmllr_options: MetaDict, - tmp_lat_paths: Dict[str, str], - final_lat_paths: Dict[str, str], + tmp_lat_paths: dict[str, str], + final_lat_paths: dict[str, str], ) -> None: """ Multiprocessing function to rescore lattices following fMLLR estimation + See Also + -------- + :func:`~montreal_forced_aligner.transcription.Transcriber.transcribe_fmllr` + Main function that calls this function in parallel + :meth:`.Transcriber.fmllr_rescore_arguments` + Job method for generating arguments for this function + :kaldi_src:`gmm-rescore-lattice` + Relevant Kaldi binary + :kaldi_src:`lattice-determinize-pruned` + Relevant Kaldi binary + Parameters ---------- log_path: str Path to save log output - dictionaries: List[str] + dictionaries: list[str] List of dictionary names - feature_strings: Dict[str, str] + feature_strings: dict[str, str] 
Dictionary of feature strings per dictionary name model_path: str Path to acoustic model file - fmllr_options: :class:`~montreal_forced_aligner.abc.MetaDict` + fmllr_options: dict[str, Any] Options for calculating fMLLR transforms - tmp_lat_paths: Dict[str, str] + tmp_lat_paths: dict[str, str] Dictionary of temporary lattice archive paths per dictionary name - final_lat_paths: Dict[str, str] + final_lat_paths: dict[str, str] Dictionary of lattice archive paths per dictionary name """ with open(log_path, "w", encoding="utf8") as log_file: @@ -1204,47 +1373,3 @@ def fmllr_rescore_func( ) determinize_proc.communicate() - - -def transcribe_fmllr(transcriber: Transcriber) -> None: - """ - Run fMLLR estimation over initial decoding lattices and rescore - - Parameters - ---------- - transcriber: :class:`~montreal_forced_aligner.transcriber.Transcriber` - Transcriber - """ - jobs = [x.initial_fmllr_arguments(transcriber) for x in transcriber.corpus.jobs] - - run_non_mp(initial_fmllr_func, jobs, transcriber.working_log_directory) - transcriber.speaker_independent = False - - jobs = [x.lat_gen_fmllr_arguments(transcriber) for x in transcriber.corpus.jobs] - - run_non_mp(lat_gen_fmllr_func, jobs, transcriber.working_log_directory) - - jobs = [x.final_fmllr_arguments(transcriber) for x in transcriber.corpus.jobs] - - run_non_mp(final_fmllr_est_func, jobs, transcriber.working_log_directory) - - jobs = [x.fmllr_rescore_arguments(transcriber) for x in transcriber.corpus.jobs] - - if transcriber.transcribe_config.use_mp: - run_mp(fmllr_rescore_func, jobs, transcriber.working_log_directory) - else: - run_non_mp(fmllr_rescore_func, jobs, transcriber.working_log_directory) - - jobs = [x.lm_rescore_arguments(transcriber) for x in transcriber.corpus.jobs] - - if transcriber.transcribe_config.use_mp: - run_mp(lm_rescore_func, jobs, transcriber.working_log_directory) - else: - run_non_mp(lm_rescore_func, jobs, transcriber.working_log_directory) - - jobs = 
[x.carpa_lm_rescore_arguments(transcriber) for x in transcriber.corpus.jobs] - - if transcriber.transcribe_config.use_mp: - run_mp(carpa_lm_rescore_func, jobs, transcriber.working_log_directory) - else: - run_non_mp(carpa_lm_rescore_func, jobs, transcriber.working_log_directory) diff --git a/montreal_forced_aligner/transcription/transcriber.py b/montreal_forced_aligner/transcription/transcriber.py new file mode 100644 index 00000000..6957fa72 --- /dev/null +++ b/montreal_forced_aligner/transcription/transcriber.py @@ -0,0 +1,935 @@ +""" +Transcription +============= + +""" +from __future__ import annotations + +import multiprocessing as mp +import os +import shutil +import subprocess +import sys +import time +from abc import abstractmethod +from typing import TYPE_CHECKING, Optional + +import yaml + +from ..abc import FileExporterMixin, TopLevelMfaWorker +from ..corpus.acoustic_corpus import AcousticCorpusPronunciationMixin +from ..exceptions import KaldiProcessingError, PlatformError +from ..helper import parse_old_features, score +from ..models import AcousticModel, LanguageModel +from ..utils import log_kaldi_errors, run_mp, run_non_mp, thirdparty_binary +from .multiprocessing import ( + CarpaLmRescoreArguments, + CreateHclgArguments, + DecodeArguments, + FinalFmllrArguments, + FmllrRescoreArguments, + InitialFmllrArguments, + LatGenFmllrArguments, + LmRescoreArguments, + ScoreArguments, + carpa_lm_rescore_func, + create_hclg_func, + decode_func, + final_fmllr_est_func, + fmllr_rescore_func, + initial_fmllr_func, + lat_gen_fmllr_func, + lm_rescore_func, + score_func, +) + +if TYPE_CHECKING: + from argparse import Namespace + + from ..abc import MetaDict + +__all__ = ["Transcriber", "TranscriberMixin"] + + +class TranscriberMixin: + """Abstract class for MFA transcribers + + Parameters + ---------- + transition_scale: float + Transition scale, defaults to 1.0 + acoustic_scale: float + Acoustic scale, defaults to 0.1 + self_loop_scale: float + Self-loop scale, 
defaults to 0.1 + beam: int + Size of the beam to use in decoding, defaults to 10 + silence_weight: float + Weight on silence in fMLLR estimation + max_active: int + Max active for decoding + lattice_beam: int + Beam width for decoding lattices + first_beam: int + Beam for decoding in initial speaker-independent pass, only used if ``uses_speaker_adaptation`` is true + first_max_active: int + Max active for decoding in initial speaker-independent pass, only used if ``uses_speaker_adaptation`` is true + language_model_weight: float + Weight of language model + word_insertion_penalty: float + Penalty for inserting words + """ + + def __init__( + self, + transition_scale: float = 1.0, + acoustic_scale: float = 0.083333, + self_loop_scale: float = 0.1, + beam: int = 10, + silence_weight: float = 0.01, + max_active: int = 7000, + lattice_beam: int = 6, + first_beam: int = 0, + first_max_active: int = 2000, + language_model_weight: int = 10, + word_insertion_penalty: float = 0.5, + **kwargs, + ): + super().__init__(**kwargs) + + self.beam = beam + self.acoustic_scale = acoustic_scale + self.self_loop_scale = self_loop_scale + self.transition_scale = transition_scale + self.silence_weight = silence_weight + self.max_active = max_active + self.lattice_beam = lattice_beam + self.first_beam = first_beam + self.first_max_active = first_max_active + self.language_model_weight = language_model_weight + self.word_insertion_penalty = word_insertion_penalty + + @abstractmethod + def create_decoding_graph(self) -> None: + """Create decoding graph for use in transcription""" + ... + + @abstractmethod + def transcribe(self) -> None: + """Perform transcription""" + ... + + @property + @abstractmethod + def model_path(self) -> str: + """Acoustic model file path""" + ... 
+ + @property + def decode_options(self) -> MetaDict: + """Options needed for decoding""" + return { + "first_beam": self.first_beam, + "beam": self.beam, + "first_max_active": self.first_max_active, + "max_active": self.max_active, + "lattice_beam": self.lattice_beam, + "acoustic_scale": self.acoustic_scale, + "uses_speaker_adaptation": self.uses_speaker_adaptation, + } + + @property + def score_options(self) -> MetaDict: + """Options needed for scoring lattices""" + return { + "language_model_weight": self.language_model_weight, + "word_insertion_penalty": self.word_insertion_penalty, + } + + @property + def transcribe_fmllr_options(self) -> MetaDict: + """Options needed for calculating fMLLR transformations""" + return { + "acoustic_scale": self.acoustic_scale, + "silence_weight": self.silence_weight, + "lattice_beam": self.lattice_beam, + } + + @property + def lm_rescore_options(self) -> MetaDict: + """Options needed for rescoring the language model""" + return { + "acoustic_scale": self.acoustic_scale, + } + + +class Transcriber( + AcousticCorpusPronunciationMixin, TranscriberMixin, FileExporterMixin, TopLevelMfaWorker +): + """ + Class for performing transcription. 
+ + Parameters + ---------- + acoustic_model_path : str + Path to acoustic model + language_model_path : str + Path to language model model + evaluation_mode: bool + Flag for evaluating generated transcripts against the actual transcripts, defaults to False + min_language_model_weight: int + Minimum language model weight to use in evaluation mode, defaults to 7 + max_language_model_weight: int + Maximum language model weight to use in evaluation mode, defaults to 17 + word_insertion_penalties: list[float] + List of word insertion penalties to use in evaluation mode, defaults to [0, 0.5, 1.0] + + See Also + -------- + :class:`~montreal_forced_aligner.transcription.transcriber.TranscriberMixin` + For transcription parameters + :class:`~montreal_forced_aligner.corpus.acoustic_corpus.AcousticCorpusPronunciationMixin` + For corpus and dictionary parsing parameters + :class:`~montreal_forced_aligner.abc.FileExporterMixin` + For file exporting parameters + :class:`~montreal_forced_aligner.abc.TopLevelMfaWorker` + For top-level parameters + + Attributes + ---------- + acoustic_model: AcousticModel + Acoustic model + language_model: LanguageModel + Language model + """ + + def __init__( + self, + acoustic_model_path: str, + language_model_path: str, + evaluation_mode: bool = False, + min_language_model_weight: int = 7, + max_language_model_weight: int = 17, + word_insertion_penalties: list[float] = None, + **kwargs, + ): + self.acoustic_model = AcousticModel(acoustic_model_path) + kwargs.update(self.acoustic_model.parameters) + super(Transcriber, self).__init__(**kwargs) + self.language_model = LanguageModel(language_model_path, self.model_directory) + if word_insertion_penalties is None: + word_insertion_penalties = [0, 0.5, 1.0] + self.min_language_model_weight = min_language_model_weight + self.max_language_model_weight = max_language_model_weight + self.evaluation_mode = evaluation_mode + self.word_insertion_penalties = word_insertion_penalties + + @classmethod + def 
parse_parameters( + cls, + config_path: Optional[str] = None, + args: Optional[Namespace] = None, + unknown_args: Optional[list[str]] = None, + ) -> MetaDict: + """ + Parse configuration parameters from a config file and command line arguments + + Parameters + ---------- + config_path: str, optional + Path to yaml configuration file + args: :class:`~argparse.Namespace`, optional + Arguments parsed by argparse + unknown_args: list[str], optional + List of unknown arguments from argparse + + Returns + ------- + dict[str, Any] + Dictionary of specified configuration parameters + """ + global_params = {} + if config_path and os.path.exists(config_path): + with open(config_path, "r", encoding="utf8") as f: + data = yaml.load(f, Loader=yaml.SafeLoader) + data = parse_old_features(data) + for k, v in data.items(): + if k == "features": + global_params.update(v) + else: + global_params[k] = v + global_params.update(cls.parse_args(args, unknown_args)) + return global_params + + def setup(self) -> None: + """Set up transcription""" + if self.initialized: + return + begin = time.time() + os.makedirs(self.working_log_directory, exist_ok=True) + check = self.check_previous_run() + if check: + self.logger.debug( + "There were some differences in the current run compared to the last one. " + "This may cause issues, run with --clean, if you hit an error." 
+ ) + self.load_corpus() + self.acoustic_model.validate(self) + self.acoustic_model.export_model(self.model_directory) + self.acoustic_model.export_model(self.working_directory) + self.acoustic_model.log_details(self.logger) + self.create_decoding_graph() + self.initialized = True + self.logger.debug(f"Setup for transcription in {time.time() - begin} seconds") + + def create_hclgs_arguments(self) -> dict[str, CreateHclgArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.transcription.multiprocessing.create_hclg_func` + + Returns + ------- + dict[str, :class:`~montreal_forced_aligner.transcription.multiprocessing.CreateHclgArguments`] + Per dictionary arguments for HCLG + """ + args = {} + for dict_name, dictionary in self.dictionary_mapping.items(): + args[dict_name] = CreateHclgArguments( + os.path.join(self.model_directory, "log", f"hclg.{dict_name}.log"), + self.model_directory, + os.path.join(self.model_directory, "{file_name}" + f".{dict_name}.fst"), + os.path.join(self.model_directory, f"words.{dict_name}.txt"), + os.path.join(self.model_directory, f"G.{dict_name}.carpa"), + self.language_model.small_arpa_path, + self.language_model.medium_arpa_path, + self.language_model.carpa_path, + self.model_path, + dictionary.lexicon_disambig_fst_path, + os.path.join(dictionary.phones_dir, "disambiguation_symbols.int"), + self.hclg_options, + dictionary.words_mapping, + ) + return args + + def decode_arguments(self) -> list[DecodeArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.transcription.multiprocessing.decode_func` + + Returns + ------- + list[:class:`~montreal_forced_aligner.transcription.multiprocessing.DecodeArguments`] + Arguments for processing + """ + feat_string = self.construct_feature_proc_strings() + return [ + DecodeArguments( + os.path.join(self.working_log_directory, f"decode.{j.name}.log"), + j.current_dictionary_names, + feat_string[j.name], + self.decode_options, + self.alignment_model_path, + 
j.construct_path_dictionary(self.working_directory, "lat", "ark"), + j.construct_dictionary_dependent_paths(self.model_directory, "words", "txt"), + j.construct_dictionary_dependent_paths(self.model_directory, "HCLG", "fst"), + ) + for j in self.jobs + ] + + def score_arguments(self) -> list[ScoreArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.transcription.multiprocessing.score_func` + + Returns + ------- + list[:class:`~montreal_forced_aligner.transcription.multiprocessing.ScoreArguments`] + Arguments for processing + """ + return [ + ScoreArguments( + os.path.join(self.working_log_directory, f"score.{j.name}.log"), + j.current_dictionary_names, + self.score_options, + j.construct_path_dictionary(self.working_directory, "lat", "ark"), + j.construct_path_dictionary(self.working_directory, "lat.rescored", "ark"), + j.construct_path_dictionary(self.working_directory, "lat.carpa.rescored", "ark"), + j.construct_dictionary_dependent_paths(self.model_directory, "words", "txt"), + j.construct_path_dictionary(self.evaluation_directory, "tra", "scp"), + ) + for j in self.jobs + ] + + def lm_rescore_arguments(self) -> list[LmRescoreArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.transcription.multiprocessing.lm_rescore_func` + + Returns + ------- + list[:class:`~montreal_forced_aligner.transcription.multiprocessing.LmRescoreArguments`] + Arguments for processing + """ + return [ + LmRescoreArguments( + os.path.join(self.working_log_directory, f"lm_rescore.{j.name}.log"), + j.current_dictionary_names, + self.lm_rescore_options, + j.construct_path_dictionary(self.working_directory, "lat", "ark"), + j.construct_path_dictionary(self.working_directory, "lat.rescored", "ark"), + j.construct_dictionary_dependent_paths(self.model_directory, "G.small", "fst"), + j.construct_dictionary_dependent_paths(self.model_directory, "G.med", "fst"), + ) + for j in self.jobs + ] + + def carpa_lm_rescore_arguments(self) -> 
list[CarpaLmRescoreArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.transcription.multiprocessing.carpa_lm_rescore_func` + + Returns + ------- + list[:class:`~montreal_forced_aligner.transcription.multiprocessing.CarpaLmRescoreArguments`] + Arguments for processing + """ + return [ + CarpaLmRescoreArguments( + os.path.join(self.working_log_directory, f"carpa_lm_rescore.{j.name}.log"), + j.current_dictionary_names, + j.construct_path_dictionary(self.working_directory, "lat.rescored", "ark"), + j.construct_path_dictionary(self.working_directory, "lat.carpa.rescored", "ark"), + j.construct_dictionary_dependent_paths(self.model_directory, "G.med", "fst"), + j.construct_dictionary_dependent_paths(self.model_directory, "G", "carpa"), + ) + for j in self.jobs + ] + + @property + def fmllr_options(self) -> MetaDict: + """Options for calculating fMLLR""" + options = super().fmllr_options + options["acoustic_scale"] = self.acoustic_scale + options["sil_phones"] = self.silence_csl + options["lattice_beam"] = self.lattice_beam + return options + + def initial_fmllr_arguments(self) -> list[InitialFmllrArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.transcription.multiprocessing.initial_fmllr_func` + + Returns + ------- + list[:class:`~montreal_forced_aligner.transcription.multiprocessing.InitialFmllrArguments`] + Arguments for processing + """ + feat_strings = self.construct_feature_proc_strings() + return [ + InitialFmllrArguments( + os.path.join(self.working_log_directory, f"initial_fmllr.{j.name}.log"), + j.current_dictionary_names, + feat_strings[j.name], + self.model_path, + self.fmllr_options, + j.construct_path_dictionary(self.working_directory, "trans", "ark"), + j.construct_path_dictionary(self.working_directory, "lat", "ark"), + j.construct_path_dictionary(self.data_directory, "spk2utt", "scp"), + ) + for j in self.jobs + ] + + def lat_gen_fmllr_arguments(self) -> list[LatGenFmllrArguments]: + """ + Generate 
Job arguments for :func:`~montreal_forced_aligner.transcription.multiprocessing.lat_gen_fmllr_func` + + Returns + ------- + list[:class:`~montreal_forced_aligner.transcription.multiprocessing.LatGenFmllrArguments`] + Arguments for processing + """ + feat_strings = self.construct_feature_proc_strings() + return [ + LatGenFmllrArguments( + os.path.join(self.working_log_directory, f"lat_gen_fmllr.{j.name}.log"), + j.current_dictionary_names, + feat_strings[j.name], + self.model_path, + self.decode_options, + j.construct_dictionary_dependent_paths(self.model_directory, "words", "txt"), + j.construct_dictionary_dependent_paths(self.model_directory, "HCLG", "fst"), + j.construct_path_dictionary(self.working_directory, "lat.tmp", "ark"), + ) + for j in self.jobs + ] + + def final_fmllr_arguments(self) -> list[FinalFmllrArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.transcription.multiprocessing.final_fmllr_est_func` + + Returns + ------- + list[:class:`~montreal_forced_aligner.transcription.multiprocessing.FinalFmllrArguments`] + Arguments for processing + """ + feat_strings = self.construct_feature_proc_strings() + return [ + FinalFmllrArguments( + os.path.join(self.working_log_directory, f"final_fmllr.{j.name}.log"), + j.current_dictionary_names, + feat_strings[j.name], + self.model_path, + self.fmllr_options, + j.construct_path_dictionary(self.working_directory, "trans", "ark"), + j.construct_path_dictionary(self.data_directory, "spk2utt", "scp"), + j.construct_path_dictionary(self.working_directory, "lat.tmp", "ark"), + ) + for j in self.jobs + ] + + def fmllr_rescore_arguments(self) -> list[FmllrRescoreArguments]: + """ + Generate Job arguments for :func:`~montreal_forced_aligner.transcription.multiprocessing.fmllr_rescore_func` + + Returns + ------- + list[:class:`~montreal_forced_aligner.transcription.multiprocessing.FmllrRescoreArguments`] + Arguments for processing + """ + feat_strings = self.construct_feature_proc_strings() + 
return [ + FmllrRescoreArguments( + os.path.join(self.working_log_directory, f"fmllr_rescore.{j.name}.log"), + j.current_dictionary_names, + feat_strings[j.name], + self.model_path, + self.fmllr_options, + j.construct_path_dictionary(self.working_directory, "lat.tmp", "ark"), + j.construct_path_dictionary(self.working_directory, "lat", "ark"), + ) + for j in self.jobs + ] + + @property + def workflow_identifier(self) -> str: + """Transcriber identifier""" + return "transcriber" + + @property + def evaluation_directory(self): + """Evaluation directory path for the current language model weight and word insertion penalty""" + eval_string = f"eval_{self.language_model_weight}_{self.word_insertion_penalty}" + path = os.path.join(self.working_directory, eval_string) + os.makedirs(path, exist_ok=True) + return path + + @property + def evaluation_log_directory(self) -> str: + """Log directory for the current evaluation""" + return os.path.join(self.evaluation_directory, "log") + + @property + def model_directory(self) -> str: + """Model directory for the transcriber""" + return os.path.join(self.output_directory, "models") + + @property + def model_path(self) -> str: + """Acoustic model file path""" + return os.path.join(self.working_directory, "final.mdl") + + @property + def alignment_model_path(self) -> str: + """Alignment (speaker-independent) acoustic model file path""" + path = os.path.join(self.working_directory, "final.alimdl") + if os.path.exists(path): + return path + return self.model_path + + @property + def hclg_options(self): + """Options for constructing HCLG FSTs""" + context_width, central_pos = self.get_tree_info() + return { + "context_width": context_width, + "central_pos": central_pos, + "self_loop_scale": self.self_loop_scale, + "transition_scale": self.transition_scale, + } + + def get_tree_info(self) -> tuple[int, int]: + """ + Get the context width and central position for the acoustic model + + Returns + ------- + int + Context width + int + 
Central position + """ + tree_proc = subprocess.Popen( + [thirdparty_binary("tree-info"), os.path.join(self.model_directory, "tree")], + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + stdout, _ = tree_proc.communicate() + context_width = 1 + central_pos = 0 + for line in stdout.split("\n"): + text = line.strip().split(" ") + if text[0] == "context-width": + context_width = int(text[1]) + elif text[0] == "central-position": + central_pos = int(text[1]) + return context_width, central_pos + + def create_hclgs(self): + """ + Create HCLG.fst files for every dictionary being used by a :class:`~montreal_forced_aligner.transcription.transcriber.Transcriber` + """ + dict_arguments = self.create_hclgs_arguments() + + dict_arguments = list(dict_arguments.values()) + if self.use_mp: + run_mp(create_hclg_func, dict_arguments, self.working_log_directory) + else: + run_non_mp(create_hclg_func, dict_arguments, self.working_log_directory) + error_logs = [] + for arg in dict_arguments: + if not os.path.exists(arg.hclg_path): + error_logs.append(arg.log_path) + if error_logs: + raise KaldiProcessingError(error_logs) + + def create_decoding_graph(self) -> None: + """ + Create decoding graph for use in transcription + + Raises + ------ + :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` + If there were any errors in running Kaldi binaries + """ + dirty_path = os.path.join(self.model_directory, "dirty") + + if os.path.exists(dirty_path): # if there was an error, let's redo from scratch + shutil.rmtree(self.model_directory) + log_dir = os.path.join(self.model_directory, "log") + os.makedirs(log_dir, exist_ok=True) + self.write_lexicon_information(write_disambiguation=True) + for dict_name, dictionary in self.dictionary_mapping.items(): + words_path = os.path.join(self.model_directory, f"words.{dict_name}.txt") + shutil.copyfile(dictionary.words_symbol_path, words_path) + + big_arpa_path = self.language_model.carpa_path + small_arpa_path = 
self.language_model.small_arpa_path + medium_arpa_path = self.language_model.medium_arpa_path + if not os.path.exists(small_arpa_path) or not os.path.exists(medium_arpa_path): + self.logger.warning( + "Creating small and medium language models from scratch, this may take some time. " + "Running `mfa train_lm` on the ARPA file will remove this warning." + ) + if sys.platform == "win32": + raise PlatformError("ngram") + self.logger.info("Parsing large ngram model...") + mod_path = os.path.join(self.model_directory, "base_lm.mod") + new_carpa_path = os.path.join(self.model_directory, "base_lm.arpa") + with open(big_arpa_path, "r", encoding="utf8") as inf, open( + new_carpa_path, "w", encoding="utf8" + ) as outf: + for line in inf: + outf.write(line.lower()) + big_arpa_path = new_carpa_path + subprocess.call(["ngramread", "--ARPA", big_arpa_path, mod_path]) + + if not os.path.exists(small_arpa_path): + self.logger.info( + "Generating small model from the large ARPA with a pruning threshold of 3e-7" + ) + prune_thresh_small = 0.0000003 + small_mod_path = mod_path.replace(".mod", "_small.mod") + subprocess.call( + [ + "ngramshrink", + "--method=relative_entropy", + f"--theta={prune_thresh_small}", + mod_path, + small_mod_path, + ] + ) + subprocess.call(["ngramprint", "--ARPA", small_mod_path, small_arpa_path]) + + if not os.path.exists(medium_arpa_path): + self.logger.info( + "Generating medium model from the large ARPA with a pruning threshold of 1e-7" + ) + prune_thresh_medium = 0.0000001 + med_mod_path = mod_path.replace(".mod", "_med.mod") + subprocess.call( + [ + "ngramshrink", + "--method=relative_entropy", + f"--theta={prune_thresh_medium}", + mod_path, + med_mod_path, + ] + ) + subprocess.call(["ngramprint", "--ARPA", med_mod_path, medium_arpa_path]) + try: + self.create_hclgs() + except Exception as e: + with open(dirty_path, "w"): + pass + if isinstance(e, KaldiProcessingError): + log_kaldi_errors(e.error_logs, self.logger) + e.update_log_file(self.logger) + 
raise
+
+    def score_transcriptions(self):
+        """
+        Score transcriptions if reference text is available in the corpus
+
+        See Also
+        --------
+        :func:`~montreal_forced_aligner.transcription.multiprocessing.score_func`
+            Multiprocessing helper function for each job
+        :meth:`.Transcriber.score_arguments`
+            Job method for generating arguments for this function
+
+        """
+        if self.evaluation_mode:
+            best_wer = 10000
+            best = None
+            for lmwt in range(self.min_language_model_weight, self.max_language_model_weight):
+                for wip in self.word_insertion_penalties:
+                    self.language_model_weight = lmwt
+                    self.word_insertion_penalty = wip
+                    os.makedirs(self.evaluation_log_directory, exist_ok=True)
+
+                    jobs = self.score_arguments()
+                    if self.use_mp:
+                        run_mp(score_func, jobs, self.evaluation_log_directory)
+                    else:
+                        run_non_mp(score_func, jobs, self.evaluation_log_directory)
+                    ser, wer = self.evaluate()
+                    if wer < best_wer:
+                        best_wer = wer
+                        best = (lmwt, wip)
+            self.language_model_weight = best[0]
+            self.word_insertion_penalty = best[1]
+            for score_args in self.score_arguments():
+                for p in score_args.tra_paths.values():
+                    shutil.copyfile(
+                        p,
+                        p.replace(self.evaluation_directory, self.working_directory),
+                    )
+        else:
+            jobs = self.score_arguments()
+            if self.use_mp:
+                run_mp(score_func, jobs, self.working_log_directory)
+            else:
+                run_non_mp(score_func, jobs, self.working_log_directory)
+
+    def transcribe_fmllr(self) -> None:
+        """
+        Run fMLLR estimation over initial decoding lattices and rescore
+
+        See Also
+        --------
+        :func:`~montreal_forced_aligner.transcription.multiprocessing.initial_fmllr_func`
+            Multiprocessing helper function for each job
+        :func:`~montreal_forced_aligner.transcription.multiprocessing.lat_gen_fmllr_func`
+            Multiprocessing helper function for each job
+        :func:`~montreal_forced_aligner.transcription.multiprocessing.final_fmllr_est_func`
+            Multiprocessing helper function for each job
+        :func:`~montreal_forced_aligner.transcription.multiprocessing.fmllr_rescore_func`
+
Multiprocessing helper function for each job + :func:`~montreal_forced_aligner.transcription.multiprocessing.lm_rescore_func` + Multiprocessing helper function for each job + :func:`~montreal_forced_aligner.transcription.multiprocessing.carpa_lm_rescore_func` + Multiprocessing helper function for each job + + """ + jobs = self.initial_fmllr_arguments() + + if self.use_mp: + run_mp(initial_fmllr_func, jobs, self.working_log_directory) + else: + run_non_mp(initial_fmllr_func, jobs, self.working_log_directory) + + self.speaker_independent = False + + jobs = self.lat_gen_fmllr_arguments() + + if self.use_mp: + run_mp(lat_gen_fmllr_func, jobs, self.working_log_directory) + else: + run_non_mp(lat_gen_fmllr_func, jobs, self.working_log_directory) + + jobs = self.final_fmllr_arguments() + + if self.use_mp: + run_mp(final_fmllr_est_func, jobs, self.working_log_directory) + else: + run_non_mp(final_fmllr_est_func, jobs, self.working_log_directory) + + jobs = self.fmllr_rescore_arguments() + + if self.use_mp: + run_mp(fmllr_rescore_func, jobs, self.working_log_directory) + else: + run_non_mp(fmllr_rescore_func, jobs, self.working_log_directory) + + jobs = self.lm_rescore_arguments() + + if self.use_mp: + run_mp(lm_rescore_func, jobs, self.working_log_directory) + else: + run_non_mp(lm_rescore_func, jobs, self.working_log_directory) + + jobs = self.carpa_lm_rescore_arguments() + + if self.use_mp: + run_mp(carpa_lm_rescore_func, jobs, self.working_log_directory) + else: + run_non_mp(carpa_lm_rescore_func, jobs, self.working_log_directory) + + def transcribe(self) -> None: + """ + Transcribe the corpus + + See Also + -------- + :func:`~montreal_forced_aligner.transcription.multiprocessing.decode_func` + Multiprocessing helper function for each job + :func:`~montreal_forced_aligner.transcription.multiprocessing.lm_rescore_func` + Multiprocessing helper function for each job + :func:`~montreal_forced_aligner.transcription.multiprocessing.carpa_lm_rescore_func` + Multiprocessing 
helper function for each job + + Raises + ------ + :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` + If there were any errors in running Kaldi binaries + """ + self.logger.info("Beginning transcription...") + dirty_path = os.path.join(self.working_directory, "dirty") + if os.path.exists(dirty_path): + shutil.rmtree(self.working_directory, ignore_errors=True) + os.makedirs(self.working_log_directory, exist_ok=True) + try: + self.speaker_independent = True + jobs = self.decode_arguments() + + if self.use_mp: + run_mp(decode_func, jobs, self.working_log_directory) + else: + run_non_mp(decode_func, jobs, self.working_log_directory) + + jobs = self.lm_rescore_arguments() + + if self.use_mp: + run_mp(lm_rescore_func, jobs, self.working_log_directory) + else: + run_non_mp(lm_rescore_func, jobs, self.working_log_directory) + + jobs = self.carpa_lm_rescore_arguments() + + if self.use_mp: + run_mp(carpa_lm_rescore_func, jobs, self.working_log_directory) + else: + run_non_mp(carpa_lm_rescore_func, jobs, self.working_log_directory) + if self.uses_speaker_adaptation: + self.logger.info("Performing speaker adjusted transcription...") + self.transcribe_fmllr() + self.score_transcriptions() + except Exception as e: + with open(dirty_path, "w"): + pass + if isinstance(e, KaldiProcessingError): + log_kaldi_errors(e.error_logs, self.logger) + e.update_log_file(self.logger) + raise + + def evaluate(self): + """ + Evaluates the transcripts if there are reference transcripts + + Raises + ------ + :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` + If there were any errors in running Kaldi binaries + """ + self.logger.info("Evaluating transcripts...") + self._load_transcripts() + # Sentence-level measures + + correct = 0 + incorrect = 0 + # Word-level measures + total_edits = 0 + total_length = 0 + issues = [] + with mp.Pool(self.num_jobs) as pool: + to_comp = [] + for utt_name, utterance in self.utterances.items(): + g = utterance.text.split() + if not 
utterance.transcription_text:
+                    incorrect += 1
+                    total_edits += len(g)
+                    total_length += len(g)
+                    continue
+                h = utterance.transcription_text.split()
+                if g != h:
+                    issues.append((utt_name, g, h))
+                to_comp.append((g, h))
+            gen = pool.starmap(score, to_comp)
+            for (edits, length) in gen:
+                if edits == 0:
+                    correct += 1
+                else:
+                    incorrect += 1
+                total_edits += edits
+                total_length += length
+        ser = 100 * incorrect / (correct + incorrect)
+        wer = 100 * total_edits / total_length
+        output_path = os.path.join(self.evaluation_directory, "transcription_issues.csv")
+        with open(output_path, "w", encoding="utf8") as f:
+            for utt, g, h in issues:
+                g = " ".join(g)
+                h = " ".join(h)
+                f.write(f"{utt},{g},{h}\n")
+        self.logger.info(f"SER: {ser:.2f}%, WER: {wer:.2f}%")
+        return ser, wer
+
+    def _load_transcripts(self):
+        """Load transcripts from Kaldi temporary files"""
+        for score_args in self.score_arguments():
+            for tra_path in score_args.tra_paths.values():
+
+                with open(tra_path, "r", encoding="utf8") as f:
+                    for line in f:
+                        t = line.strip().split(" ")
+                        utt = t[0]
+                        utterance = self.utterances[utt]
+                        speaker = utterance.speaker
+                        lookup = speaker.dictionary.reversed_word_mapping
+                        ints = t[1:]
+                        if not ints:
+                            continue
+                        transcription = []
+                        for i in ints:
+                            transcription.append(lookup[int(i)])
+                        utterance.transcription_text = " ".join(transcription)
+
+    def export_files(self, output_directory: str) -> None:
+        """
+        Export transcriptions
+
+        Parameters
+        ----------
+        output_directory: str
+            Directory to save transcriptions
+        """
+        backup_output_directory = None
+        if not self.overwrite:
+            backup_output_directory = os.path.join(self.working_directory, "transcriptions")
+            os.makedirs(backup_output_directory, exist_ok=True)
+        self._load_transcripts()
+        for file in self.files.values():
+            file.save(output_directory, backup_output_directory)
diff --git a/montreal_forced_aligner/utils.py b/montreal_forced_aligner/utils.py index d045c876..5d6e1296 100644 ---
a/montreal_forced_aligner/utils.py +++ b/montreal_forced_aligner/utils.py @@ -6,47 +6,45 @@ from __future__ import annotations import logging +import multiprocessing as mp import os import shutil import sys import textwrap -from typing import TYPE_CHECKING, Any, Dict, List, Union +import traceback +from queue import Empty +from typing import Any, Callable, Optional, Union -import yaml from colorama import Fore, Style from .exceptions import KaldiProcessingError, ThirdpartyError from .models import MODEL_TYPES -if TYPE_CHECKING: - from .config.base_config import BaseConfig - __all__ = [ "thirdparty_binary", - "get_available_dictionaries", - "log_config", "log_kaldi_errors", - "get_available_models", - "get_available_language_models", - "get_available_acoustic_models", - "get_available_g2p_models", - "get_pretrained_language_model_path", - "get_pretrained_g2p_path", - "get_pretrained_ivector_path", - "get_pretrained_path", - "get_pretrained_acoustic_path", - "get_dictionary_path", - "get_available_ivector_extractors", "guess_model_type", "parse_logs", - "setup_logger", "CustomFormatter", + "Counter", + "Stopped", + "ProcessWorker", + "run_mp", + "run_non_mp", ] -def get_mfa_version(): +def get_mfa_version() -> str: + """ + Get the current MFA version + + Returns + ------- + str + MFA version + """ try: - from .version import version as __version__ # noqa + from ._version import version as __version__ # noqa except ImportError: __version__ = "2.0.0" return __version__ @@ -79,49 +77,13 @@ def thirdparty_binary(binary_name: str) -> str: return bin_path -def parse_logs(log_directory: str) -> None: - """ - Parse the output of a Kaldi run for any errors and raise relevant MFA exceptions - - Parameters - ---------- - log_directory: str - Log directory to parse - - Raises - ------ - KaldiProcessingError - If any log files contained error lines - - """ - error_logs = [] - for name in os.listdir(log_directory): - log_path = os.path.join(log_directory, name) - with 
open(log_path, "r", encoding="utf8") as f: - for line in f: - line = line.strip() - if "error while loading shared libraries: libopenblas.so.0" in line: - raise ThirdpartyError("libopenblas.so.0", open_blas=True) - for libc_version in ["GLIBC_2.27", "GLIBCXX_3.4.20"]: - if libc_version in line: - raise ThirdpartyError(libc_version, libc=True) - if "sox FAIL formats" in line: - f = line.split(" ")[-1] - raise ThirdpartyError(f, sox=True) - if line.startswith("ERROR") or line.startswith("ASSERTION_FAILED"): - error_logs.append(log_path) - break - if error_logs: - raise KaldiProcessingError(error_logs) - - -def log_kaldi_errors(error_logs: List[str], logger: logging.Logger) -> None: +def log_kaldi_errors(error_logs: list[str], logger: logging.Logger) -> None: """ Save details of Kaldi processing errors to a logger Parameters ---------- - error_logs: List[str] + error_logs: list[str] Kaldi log files with errors logger: :class:`~logging.Logger` Logger to output to @@ -135,33 +97,7 @@ def log_kaldi_errors(error_logs: List[str], logger: logging.Logger) -> None: logger.debug("\t" + line.strip()) -def get_available_models(model_type: str) -> List[str]: - """ - Get a list of available models for a given model type - - Parameters - ---------- - model_type: str - Model type to search - - Returns - ------- - List[str] - List of model names - """ - from .config import TEMP_DIR - - pretrained_dir = os.path.join(TEMP_DIR, "pretrained_models", model_type) - os.makedirs(pretrained_dir, exist_ok=True) - available = [] - model_class = MODEL_TYPES[model_type] - for f in os.listdir(pretrained_dir): - if model_class is None or model_class.valid_extension(f): - available.append(os.path.splitext(f)[0]) - return available - - -def guess_model_type(path: str) -> List[str]: +def guess_model_type(path: str) -> list[str]: """ Guess a model type given a path @@ -172,7 +108,7 @@ def guess_model_type(path: str) -> List[str]: Returns ------- - List[str] + list[str] Possible model types that use 
that extension """ ext = os.path.splitext(path)[1] @@ -185,176 +121,6 @@ def guess_model_type(path: str) -> List[str]: return possible -def get_available_acoustic_models() -> List[str]: - """ - Return a list of all available acoustic models - - Returns - ------- - List[str] - Pretrained acoustic models - """ - return get_available_models("acoustic") - - -def get_available_g2p_models() -> List[str]: - """ - Return a list of all available G2P models - - Returns - ------- - List[str] - Pretrained G2P models - """ - return get_available_models("g2p") - - -def get_available_ivector_extractors() -> List[str]: - """ - Return a list of all available ivector extractors - - Returns - ------- - List[str] - Pretrained ivector extractors - """ - return get_available_models("ivector") - - -def get_available_language_models() -> List[str]: - """ - Return a list of all available language models - - Returns - ------- - List[str] - Pretrained language models - """ - return get_available_models("language_model") - - -def get_available_dictionaries() -> List[str]: - """ - Return a list of all available dictionaries - - Returns - ------- - List[str] - Saved dictionaries - """ - return get_available_models("dictionary") - - -def get_pretrained_path(model_type: str, name: str, enforce_existence: bool = True) -> str: - """ - Generate a path to a pretrained model based on its name and model type - - Parameters - ---------- - model_type: str - Type of model - name: str - Name of model - enforce_existence: bool - Flag to return None if the path doesn't exist, defaults to True - - Returns - ------- - str - Path to model - """ - from .config import TEMP_DIR - - pretrained_dir = os.path.join(TEMP_DIR, "pretrained_models", model_type) - model_class = MODEL_TYPES[model_type] - return model_class.generate_path(pretrained_dir, name, enforce_existence) - - -def get_pretrained_acoustic_path(name: str) -> str: - """ - Generate a path to a given pretrained acoustic model - - Parameters - ---------- - 
name: str - Name of model - - Returns - ------- - str - Full path to model - """ - return get_pretrained_path("acoustic", name) - - -def get_pretrained_ivector_path(name: str) -> str: - """ - Generate a path to a given pretrained ivector extractor - - Parameters - ---------- - name: str - Name of model - - Returns - ------- - str - Full path to model - """ - return get_pretrained_path("ivector", name) - - -def get_pretrained_language_model_path(name: str) -> str: - """ - Generate a path to a given pretrained language model - - Parameters - ---------- - name: str - Name of model - - Returns - ------- - str - Full path to model - """ - return get_pretrained_path("language_model", name) - - -def get_pretrained_g2p_path(name: str) -> str: - """ - Generate a path to a given pretrained G2P model - - Parameters - ---------- - name: str - Name of model - - Returns - ------- - str - Full path to model - """ - return get_pretrained_path("g2p", name) - - -def get_dictionary_path(name: str) -> str: - """ - Generate a path to a given saved dictionary - - Parameters - ---------- - name: str - Name of dictionary - - Returns - ------- - str - Full path to dictionary - """ - return get_pretrained_path("dictionary", name) - - class CustomFormatter(logging.Formatter): """ Custom log formatter class for MFA to highlight messages and incorporate terminal options from @@ -411,57 +177,248 @@ def format(self, record: logging.LogRecord): ) -def setup_logger( - identifier: str, output_directory: str, console_level: str = "info" -) -> logging.Logger: +def parse_logs(log_directory: str) -> None: + """ + Parse the output of a Kaldi run for any errors and raise relevant MFA exceptions + + Parameters + ---------- + log_directory: str + Log directory to parse + + Raises + ------ + KaldiProcessingError + If any log files contained error lines + + """ + error_logs = [] + for name in os.listdir(log_directory): + log_path = os.path.join(log_directory, name) + with open(log_path, "r", encoding="utf8") 
as f: + for line in f: + line = line.strip() + if "error while loading shared libraries: libopenblas.so.0" in line: + raise ThirdpartyError("libopenblas.so.0", open_blas=True) + for libc_version in ["GLIBC_2.27", "GLIBCXX_3.4.20"]: + if libc_version in line: + raise ThirdpartyError(libc_version, libc=True) + if "sox FAIL formats" in line: + f = line.split(" ")[-1] + raise ThirdpartyError(f, sox=True) + if line.startswith("ERROR") or line.startswith("ASSERTION_FAILED"): + error_logs.append(log_path) + break + if error_logs: + raise KaldiProcessingError(error_logs) + + +class Counter(object): + """ + Multiprocessing counter object for keeping track of progress + + Attributes + ---------- + val: :func:`~multiprocessing.Value` + Integer to increment + lock: :class:`~multiprocessing.Lock` + Lock for process safety + """ + + def __init__(self, init_val: int = 0): + self.val = mp.Value("i", init_val) + self.lock = mp.Lock() + + def increment(self) -> None: + """Increment the counter""" + with self.lock: + self.val.value += 1 + + def value(self) -> int: + """Get the current value of the counter""" + with self.lock: + return self.val.value + + +class Stopped(object): + """ + Multiprocessing class for detecting whether processes should stop processing and exit ASAP + + Attributes + ---------- + val: :func:`~multiprocessing.Value` + 0 if not stopped, 1 if stopped + lock: :class:`~multiprocessing.Lock` + Lock for process safety + _source: multiprocessing.Value + 1 if it was a Ctrl+C event that stopped it, 0 otherwise + """ + + def __init__(self, initval: Union[bool, int] = False): + self.val = mp.Value("i", initval) + self.lock = mp.Lock() + self._source = mp.Value("i", 0) + + def stop(self) -> None: + """Signal that work should stop asap""" + with self.lock: + self.val.value = True + + def stop_check(self) -> int: + """Check whether a process should stop""" + with self.lock: + return self.val.value + + def set_sigint_source(self) -> None: + """Set the source as a ctrl+c""" + 
with self.lock: + self._source.value = True + + def source(self) -> int: + """Get the source value""" + with self.lock: + return self._source.value + + +class ProcessWorker(mp.Process): """ - Construct a logger for a command line run + Multiprocessing function work + + Parameters + ---------- + job_name: int + Integer number of job + job_q: :class:`~multiprocessing.Queue` + Job queue to pull arguments from + function: Callable + Multiprocessing function to call on arguments from job_q + return_dict: dict + Dictionary for collecting errors + stopped: :class:`~montreal_forced_aligner.utils.Stopped` + Stop check + return_info: dict[int, Any], optional + Optional dictionary to fill if the function should return information to main thread + """ + + def __init__( + self, + job_name: int, + job_q: mp.Queue, + function: Callable, + return_dict: dict, + stopped: Stopped, + return_info: Optional[dict[int, Any]] = None, + ): + mp.Process.__init__(self) + self.job_name = job_name + self.function = function + self.job_q = job_q + self.return_dict = return_dict + self.return_info = return_info + self.stopped = stopped + + def run(self) -> None: + """ + Run through the arguments in the queue apply the function to them + """ + try: + arguments = self.job_q.get(timeout=1) + except Empty: + return + self.job_q.task_done() + try: + result = self.function(*arguments) + if self.return_info is not None: + self.return_info[self.job_name] = result + except Exception: + self.stopped.stop() + self.return_dict["error"] = arguments, Exception( + traceback.format_exception(*sys.exc_info()) + ) + + +def run_non_mp( + function: Callable, + argument_list: list[tuple[Any, ...]], + log_directory: str, + return_info: bool = False, +) -> Optional[dict[Any, Any]]: + """ + Similar to :func:`run_mp`, but no additional processes are used and the jobs are evaluated in sequential order Parameters ---------- - identifier: str - Name of the MFA utility - output_directory: str - Top level logging directory - 
console_level: str, optional - Level to output to the console, defaults to "info" + function: Callable + Multiprocessing function to evaluate + argument_list: list + List of arguments to process + log_directory: str + Directory that all log information from the processes goes to + return_info: dict, optional + If the function returns information, supply the return dict to populate Returns ------- - :class:`~logging.Logger` - Logger to use + dict, optional + If the function returns information, returns the dictionary it was supplied with """ - os.makedirs(output_directory, exist_ok=True) - log_path = os.path.join(output_directory, f"{identifier}.log") - if os.path.exists(log_path): - os.remove(log_path) - logger = logging.getLogger(identifier) - logger.setLevel(logging.DEBUG) - - handler = logging.FileHandler(log_path, encoding="utf8") - handler.setLevel(logging.DEBUG) - formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") - handler.setFormatter(formatter) - logger.addHandler(handler) - - handler = logging.StreamHandler(sys.stdout) - handler.setLevel(getattr(logging, console_level.upper())) - handler.setFormatter(CustomFormatter()) - logger.addHandler(handler) - logger.debug(f"Set up logger for MFA version: {get_mfa_version()}") - return logger - - -def log_config(logger: logging.Logger, config: Union[Dict[str, Any], BaseConfig]) -> None: + if return_info: + info = {} + for i, args in enumerate(argument_list): + info[i] = function(*args) + parse_logs(log_directory) + return info + + for args in argument_list: + function(*args) + parse_logs(log_directory) + + +def run_mp( + function: Callable, + argument_list: list[tuple[Any, ...]], + log_directory: str, + return_info: bool = False, +) -> Optional[dict[int, Any]]: """ - Output a configuration to a Logger + Apply a function for each job in parallel Parameters ---------- - logger: :class:`~logging.Logger` - Logger to save to - config: Dict[str, Any] or 
:class:`~montreal_forced_aligner.config.BaseConfig` - Configuration to dump - """ - stream = yaml.dump(config) - logger.debug(stream) + function: Callable + Multiprocessing function to apply + argument_list: list + List of arguments for each job + log_directory: str + Directory that all log information from the processes goes to + return_info: dict, optional + If the function returns information, supply the return dict to populate + """ + from .config import BLAS_THREADS + + os.environ["OPENBLAS_NUM_THREADS"] = f"{BLAS_THREADS}" + os.environ["MKL_NUM_THREADS"] = f"{BLAS_THREADS}" + stopped = Stopped() + manager = mp.Manager() + job_queue = manager.Queue() + return_dict = manager.dict() + info = None + if return_info: + info = manager.dict() + for a in argument_list: + job_queue.put(a) + procs = [] + for i in range(len(argument_list)): + p = ProcessWorker(i, job_queue, function, return_dict, stopped, info) + procs.append(p) + p.start() + + for p in procs: + p.join() + if "error" in return_dict: + _, exc = return_dict["error"] + raise exc + + parse_logs(log_directory) + if return_info: + return info diff --git a/montreal_forced_aligner/validator.py b/montreal_forced_aligner/validator.py index c8fd7b84..986dd1d4 100644 --- a/montreal_forced_aligner/validator.py +++ b/montreal_forced_aligner/validator.py @@ -5,136 +5,364 @@ """ from __future__ import annotations -import logging import os +import subprocess +import time from decimal import Decimal -from typing import TYPE_CHECKING, Optional - -from .abc import AcousticModelWorker -from .aligner.pretrained import PretrainedAligner -from .config import FeatureConfig -from .exceptions import CorpusError, KaldiProcessingError -from .helper import edit_distance, load_scp -from .multiprocessing import run_mp, run_non_mp -from .multiprocessing.alignment import compile_utterance_train_graphs_func, test_utterances_func -from .trainers import MonophoneTrainer -from .utils import log_kaldi_errors +from typing import TYPE_CHECKING, 
Any, NamedTuple, Optional + +import yaml + +from .acoustic_modeling.trainer import TrainableAligner +from .alignment import CorpusAligner, PretrainedAligner +from .alignment.multiprocessing import compile_information_func +from .exceptions import ConfigError, KaldiProcessingError +from .helper import TerminalPrinter, edit_distance, load_scp, save_scp +from .utils import log_kaldi_errors, run_mp, run_non_mp, thirdparty_binary if TYPE_CHECKING: - from .corpus.base import Corpus - from .dictionary import MultispeakerDictionary + from argparse import Namespace + + from .abc import MetaDict + + +__all__ = ["TrainingValidator", "PretrainedValidator"] + + +class CompileUtteranceTrainGraphsArguments(NamedTuple): + """Arguments for :func:`~montreal_forced_aligner.validator.compile_utterance_train_graphs_func`""" + + log_path: str + dictionaries: list[str] + disambig_int_paths: dict[str, str] + disambig_L_fst_paths: dict[str, str] + fst_paths: dict[str, str] + graphs_paths: dict[str, str] + model_path: str + tree_path: str + +class TestUtterancesArguments(NamedTuple): + """Arguments for :func:`~montreal_forced_aligner.validator.test_utterances_func`""" -__all__ = ["CorpusValidator"] + log_path: str + dictionaries: list[str] + feature_strings: dict[str, str] + words_paths: dict[str, str] + graphs_paths: dict[str, str] + text_int_paths: dict[str, str] + edits_paths: dict[str, str] + out_int_paths: dict[str, str] + model_path: str -class CorpusValidator(AcousticModelWorker): +def test_utterances_func( + log_path: str, + dictionaries: list[str], + feature_strings: dict[str, str], + words_paths: dict[str, str], + graphs_paths: dict[str, str], + text_int_paths: dict[str, str], + edits_paths: dict[str, str], + out_int_paths: dict[str, str], + model_path: str, +): """ - Validator class for checking whether a corpus, a dictionary, and (optionally) an acoustic model work together + Multiprocessing function to test utterance transcriptions - Parameters - ---------- - corpus : 
:class:`~montreal_forced_aligner.corpus.Corpus` - Corpus object for the dataset - dictionary : :class:`~montreal_forced_aligner.dictionary.MultispeakerDictionary` - MultispeakerDictionary object for the pronunciation dictionary - temp_directory : str, optional - Specifies the temporary directory root to save files need for Kaldi. - If not specified, it will be set to ``~/Documents/MFA`` - ignore_acoustics: bool, optional - Flag for whether all acoustics should be ignored, which speeds up the validation, defaults to False - test_transcriptions: bool, optional - Flag for whether the validator should test transcriptions, defaults to False - use_mp: bool, optional - Flag for whether to use multiprocessing - logger: :class:`~logging.Logger`, optional - Logger to use + See Also + -------- + :kaldi_src:`gmm-latgen-faster` + Relevant Kaldi binary + :kaldi_src:`lattice-oracle` + Relevant Kaldi binary - Attributes + Parameters ---------- - corpus_analysis_template: str - Template for output message - alignment_analysis_template: str - Template for output message - transcription_analysis_template: str - Template for output message + log_path: str + Log path + dictionaries: list[str] + List of dictionaries + feature_strings: dict[str, str] + Dictionary of feature strings per dictionary name + words_paths: dict[str, str] + Dictionary of word mapping files per dictionary name + graphs_paths: dict[str, str] + Dictionary of utterance FST graph archives per dictionary name + text_int_paths: dict[str, str] + Dictionary of text.int files per dictionary name + edits_paths: dict[str, str] + Dictionary of paths to save transcription differences per dictionary name + out_int_paths: dict[str, str] + Dictionary of output .int files per dictionary name + model_path: str + Acoustic model path """ + acoustic_scale = 0.1 + beam = 15.0 + lattice_beam = 8.0 + max_active = 750 + with open(log_path, "w") as log_file: + for dict_name in dictionaries: + words_path = words_paths[dict_name] + 
graphs_path = graphs_paths[dict_name] + feature_string = feature_strings[dict_name] + edits_path = edits_paths[dict_name] + text_int_path = text_int_paths[dict_name] + out_int_path = out_int_paths[dict_name] + latgen_proc = subprocess.Popen( + [ + thirdparty_binary("gmm-latgen-faster"), + f"--acoustic-scale={acoustic_scale}", + f"--beam={beam}", + f"--max-active={max_active}", + f"--lattice-beam={lattice_beam}", + f"--word-symbol-table={words_path}", + model_path, + "ark:" + graphs_path, + feature_string, + "ark:-", + ], + stderr=log_file, + stdout=subprocess.PIPE, + ) - corpus_analysis_template = """ - =========================================Corpus========================================= - {} sound files - {} sound files with .lab transcription files - {} sound files with TextGrids transcription files - {} additional sound files ignored (see below) - {} speakers - {} utterances - {} seconds total duration - - DICTIONARY - ---------- - {} + oracle_proc = subprocess.Popen( + [ + thirdparty_binary("lattice-oracle"), + "ark:-", + f"ark,t:{text_int_path}", + f"ark,t:{out_int_path}", + f"ark,t:{edits_path}", + ], + stderr=log_file, + stdin=latgen_proc.stdout, + ) + oracle_proc.communicate() - SOUND FILE READ ERRORS - ---------------------- - {} - FEATURE CALCULATION - ------------------- - {} +def compile_utterance_train_graphs_func( + log_path: str, + dictionaries: list[str], + disambig_int_paths: dict[str, str], + disambig_L_fst_paths: dict[str, str], + fst_paths: dict[str, str], + graphs_paths: dict[str, str], + model_path: str, + tree_path: str, +): + """ + Multiprocessing function to compile utterance FSTs - FILES WITHOUT TRANSCRIPTIONS - ---------------------------- - {} + See Also + -------- + :kaldi_src:`compile-train-graphs-fsts` + Relevant Kaldi binary - TRANSCRIPTIONS WITHOUT FILES - -------------------- - {} + Parameters + ---------- + log_path: str + Log path + dictionaries: list[str] + List of dictionaries + disambig_int_paths: dict[str, str] + 
Dictionary of disambiguation symbol int files per dictionary name + disambig_L_fst_paths: dict[str, str] + Dictionary of disambiguation lexicon FSTs per dictionary name + fst_paths: dict[str, str] + Dictionary of pregenerated utterance FST scp files per dictionary name + graphs_paths: dict[str, str] + Dictionary of utterance FST graph archives per dictionary name + model_path: str + Acoustic model path + tree_path: str + Acoustic model's tree path + """ + with open(log_path, "w") as log_file: + for dict_name in dictionaries: + disambig_int_path = disambig_int_paths[dict_name] + disambig_L_fst_path = disambig_L_fst_paths[dict_name] + fst_path = fst_paths[dict_name] + graphs_path = graphs_paths[dict_name] + proc = subprocess.Popen( + [ + thirdparty_binary("compile-train-graphs-fsts"), + "--transition-scale=1.0", + "--self-loop-scale=0.1", + f"--read-disambig-syms={disambig_int_path}", + tree_path, + model_path, + disambig_L_fst_path, + f"ark:{fst_path}", + f"ark:{graphs_path}", + ], + stderr=log_file, + ) - TEXTGRID READ ERRORS - -------------------- - {} + proc.communicate() - UNREADABLE TEXT FILES - -------------------- - {} - """ - alignment_analysis_template = """ - =======================================Alignment======================================== - {} +class ValidationMixin(CorpusAligner): """ + Mixin class for performing validation on a corpus + + Parameters + ---------- + ignore_acoustics: bool + Flag for whether feature generation and training/alignment should be skipped + test_transcriptions: bool + Flag for whether utterance transcriptions should be tested with a unigram language model - transcription_analysis_template = """ - ====================================Transcriptions====================================== - {} + See Also + -------- + :class:`~montreal_forced_aligner.alignment.base.CorpusAligner` + For corpus, dictionary, and alignment parameters + + Attributes + ---------- + printer: TerminalPrinter + Printer for output messages """ def 
__init__( - self, - corpus: Corpus, - dictionary: MultispeakerDictionary, - temp_directory: Optional[str] = None, - ignore_acoustics: bool = False, - test_transcriptions: bool = False, - use_mp: bool = True, - logger: Optional[logging.Logger] = None, + self, ignore_acoustics: bool = False, test_transcriptions: bool = False, **kwargs ): - super().__init__(corpus, dictionary) - self.temp_directory = temp_directory - self.test_transcriptions = test_transcriptions + kwargs["clean"] = True + super().__init__(**kwargs) self.ignore_acoustics = ignore_acoustics - self.trainer: MonophoneTrainer = MonophoneTrainer(FeatureConfig()) - self.logger = logger - self.trainer.logger = logger - self.trainer.update({"use_mp": use_mp}) - self.setup() + self.test_transcriptions = test_transcriptions + self.printer = TerminalPrinter() @property - def working_directory(self) -> str: - return os.path.join(self.temp_directory, "validation") + def workflow_identifier(self) -> str: + """Identifier for validation""" + return "validation" + + def utt2fst_scp_data( + self, num_frequent_words: int = 10 + ) -> list[dict[str, list[tuple[str, str]]]]: + """ + Generate Kaldi style utt2fst scp data + + Parameters + ---------- + num_frequent_words: int + Number of frequent words to include in the unigram language model + + Returns + ------- + dict[str, list[tuple[str, str]]] + Utterance FSTs per dictionary name + """ + job_data = [] + most_frequent = {} + for j in self.jobs: + data = {} + utts = j.job_utts() + for dict_name, utt_data in utts.items(): + data[dict_name] = [] + for u_name, utterance in utt_data.items(): + new_text = [] + dictionary = utterance.speaker.dictionary + if dict_name not in most_frequent: + word_frequencies = self.get_word_frequency() + most_frequent[dict_name] = sorted( + word_frequencies.items(), key=lambda x: -x[1] + )[:num_frequent_words] + + for t in utterance.text: + lookup = utterance.speaker.dictionary.split_clitics(t) + if lookup is None: + continue + new_text.extend(x 
for x in lookup if x != "") + data[dict_name].append( + ( + u_name, + dictionary.create_utterance_fst( + new_text, most_frequent[dictionary.name] + ), + ) + ) + job_data.append(data) + return job_data + + def output_utt_fsts(self, num_frequent_words: int = 10) -> None: + """ + Write utterance FSTs + + Parameters + ---------- + num_frequent_words: int + Number of frequent words + """ + utt2fst = self.utt2fst_scp_data(num_frequent_words) + for i, job_data in enumerate(utt2fst): + for dict_name, scp in job_data.items(): + utt2fst_scp_path = os.path.join( + self.split_directory, f"utt2fst.{dict_name}.{i}.scp" + ) + save_scp(scp, utt2fst_scp_path, multiline=True) + + def compile_utterance_train_graphs_arguments( + self, + ) -> list[CompileUtteranceTrainGraphsArguments]: + """ + Generate Job arguments for :func:`compile_utterance_train_graphs_func` + + Returns + ------- + list[:class:`~montreal_forced_aligner.validator.CompileUtteranceTrainGraphsArguments`] + Arguments for processing + """ + disambig_paths = { + k: self.disambiguation_symbols_int_path for k, v in self.dictionary_mapping.items() + } + lexicon_fst_paths = { + k: v.lexicon_disambig_fst_path for k, v in self.dictionary_mapping.items() + } + return [ + CompileUtteranceTrainGraphsArguments( + os.path.join(self.working_log_directory, f"utterance_fst.{j.name}.log"), + j.current_dictionary_names, + disambig_paths, + lexicon_fst_paths, + j.construct_path_dictionary(self.data_directory, "utt2fst", "scp"), + j.construct_path_dictionary(self.working_directory, "utterance_graphs", "fst"), + self.model_path, + self.tree_path, + ) + for j in self.jobs + ] + + def test_utterances_arguments(self) -> list[TestUtterancesArguments]: + """ + Generate Job arguments for :func:`test_utterances_func` + + Returns + ------- + list[:class:`~montreal_forced_aligner.validator.TestUtterancesArguments`] + Arguments for processing + """ + feat_strings = self.construct_feature_proc_strings() + words_paths = {k: v.words_symbol_path for k, 
v in self.dictionary_mapping.items()} + return [ + TestUtterancesArguments( + os.path.join(self.working_directory, f"utterance_fst.{j.name}.log"), + j.current_dictionary_names, + feat_strings[j.name], + words_paths, + j.construct_path_dictionary(self.working_directory, "utterance_graphs", "fst"), + j.construct_path_dictionary(self.data_directory, "text", "int.scp"), + j.construct_path_dictionary(self.working_directory, "edits", "scp"), + j.construct_path_dictionary(self.working_directory, "aligned", "int"), + self.model_path, + ) + for j in self.jobs + ] @property def working_log_directory(self) -> str: + """Working log directory""" return os.path.join(self.working_directory, "log") def setup(self): @@ -146,130 +374,212 @@ def setup(self): :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` If there were any errors in running Kaldi binaries """ - self.dictionary.set_word_set(self.corpus.word_set) - self.dictionary.write() - if self.test_transcriptions: - self.dictionary.write(write_disambiguation=True) - if self.ignore_acoustics: - fc = None - if self.logger is not None: - self.logger.info("Skipping acoustic feature generation") - else: - fc = self.trainer.feature_config try: - self.corpus.initialize_corpus(self.dictionary, fc) + self.load_corpus() + self.write_lexicon_information() if self.test_transcriptions: - self.corpus.initialize_utt_fsts() - except CorpusError: - if self.logger is not None: - self.logger.warning( - "There was an error when initializing the corpus, likely due to missing sound files. Ignoring acoustic generation..." 
- ) - self.ignore_acoustics = True + self.write_lexicon_information(write_disambiguation=True) + if self.ignore_acoustics: + self.logger.info("Skipping acoustic feature generation") + else: + self.generate_features() + + if self.test_transcriptions: + self.initialize_utt_fsts() + except Exception as e: + if isinstance(e, KaldiProcessingError): + log_kaldi_errors(e.error_logs, self.logger) + e.update_log_file(self.logger) + raise + + @property + def indent_string(self) -> str: + """Indent string to use in formatting the output messages""" + return " " * 4 + + def _print_header(self, header: str) -> None: + """ + Print a section header + + Parameters + ---------- + header: str + Section header string + """ + print() + underline = "*" * len(header) + print(self.printer.colorize(underline, "bright")) + print(self.printer.colorize(header, "bright")) + print(self.printer.colorize(underline, "bright")) + + def _print_sub_header(self, header: str) -> None: + """ + Print a subsection header - def analyze_setup(self): + Parameters + ---------- + header: str + Subsection header string + """ + underline = "=" * len(header) + print(self.printer.colorize(header, "bright")) + print(self.printer.colorize(underline, "bright")) + + def _print_green_stat(self, stat: Any, text: str) -> None: + """ + Print a statistic in green + + Parameters + ---------- + stat: Any + Statistic to print + text: str + Other text to follow statistic + """ + print(self.indent_string + f"{self.printer.colorize(stat, 'green')} {text}") + + def _print_yellow_stat(self, stat, text) -> None: + """ + Print a statistic in yellow + + Parameters + ---------- + stat: Any + Statistic to print + text: str + Other text to follow statistic + """ + print(self.indent_string + f"{self.printer.colorize(stat, 'yellow')} {text}") + + def _print_red_stat(self, stat, text) -> None: + """ + Print a statistic in red + + Parameters + ---------- + stat: Any + Statistic to print + text: str + Other text to follow statistic + """ + 
print(self.indent_string + f"{self.printer.colorize(stat, 'red')} {text}") + + def analyze_setup(self) -> None: """ Analyzes the set up process and outputs info to the console """ - total_duration = sum(x.duration for x in self.corpus.files.values()) + total_duration = sum(x.duration for x in self.files.values()) total_duration = Decimal(str(total_duration)).quantize(Decimal("0.001")) - ignored_count = len(self.corpus.no_transcription_files) - ignored_count += len(self.corpus.textgrid_read_errors) - ignored_count += len(self.corpus.decode_error_files) - self.logger.info( - self.corpus_analysis_template.format( - sum(1 for x in self.corpus.files.values() if x.wav_path is not None), - sum(1 for x in self.corpus.files.values() if x.text_type == "lab"), - sum(1 for x in self.corpus.files.values() if x.text_type == "textgrid"), - ignored_count, - len(self.corpus.speakers), - self.corpus.num_utterances, - total_duration, - self.analyze_oovs(), - self.analyze_wav_errors(), - self.analyze_missing_features(), - self.analyze_files_with_no_transcription(), - self.analyze_transcriptions_with_no_wavs(), - self.analyze_textgrid_read_errors(), - self.analyze_unreadable_text_files(), + ignored_count = len(self.no_transcription_files) + ignored_count += len(self.textgrid_read_errors) + ignored_count += len(self.decode_error_files) + num_sound_files = sum(1 for x in self.files.values() if x.wav_path is not None) + num_lab_files = sum(1 for x in self.files.values() if x.text_type == "lab") + num_textgrid_files = sum(1 for x in self.files.values() if x.text_type == "textgrid") + self._print_header("Corpus") + self._print_green_stat(num_sound_files, "sound files") + self._print_green_stat(num_lab_files, "lab files") + self._print_green_stat(num_textgrid_files, "textgrid files") + if len(self.no_transcription_files): + self._print_yellow_stat( + len(self.no_transcription_files), + "sound files without corresponding transcriptions", ) - ) + if len(self.decode_error_files): + 
self._print_red_stat(len(self.decode_error_files), "read errors for lab files") + if len(self.textgrid_read_errors): + self._print_red_stat(len(self.textgrid_read_errors), "read errors for TextGrid files") + + self._print_green_stat(len(self.speakers), "speakers") + self._print_green_stat(self.num_utterances, "utterances") + self._print_green_stat(total_duration, "seconds total duration") + print() + + self.analyze_oovs() + self.analyze_wav_errors() + self.analyze_missing_features() + self.analyze_files_with_no_transcription() + self.analyze_transcriptions_with_no_wavs() - def analyze_oovs(self) -> str: + if len(self.decode_error_files) or num_lab_files: + self.analyze_unreadable_text_files() + if len(self.textgrid_read_errors) or num_textgrid_files: + self.analyze_textgrid_read_errors() + + def analyze_oovs(self) -> None: """ Analyzes OOVs in the corpus and constructs message - - Returns - ------- - str - OOV validation result message """ - output_dir = self.corpus.output_directory - oov_types = self.dictionary.oovs_found + self._print_sub_header("Dictionary") + output_dir = self.output_directory + oov_types = self.oovs_found oov_path = os.path.join(output_dir, "oovs_found.txt") utterance_oov_path = os.path.join(output_dir, "utterance_oovs.txt") if oov_types: total_instances = 0 with open(utterance_oov_path, "w", encoding="utf8") as f: - for utt, utterance in sorted(self.corpus.utterances.items()): + for utt, utterance in sorted(self.utterances.items()): if not utterance.oovs: continue total_instances += len(utterance.oovs) f.write(f"{utt} {', '.join(utterance.oovs)}\n") - self.dictionary.save_oovs_found(output_dir) - message = ( - f"There were {len(oov_types)} word types not found in the dictionary with a total of {total_instances} instances.\n\n" - f" Please see \n\n {oov_path}\n\n for a full list of the word types and \n\n {utterance_oov_path}\n\n for a by-utterance breakdown of " - f"missing words." 
+ self.save_oovs_found(output_dir) + self._print_yellow_stat(len(oov_types), "OOV word types") + self._print_yellow_stat(total_instances, "total OOV tokens") + print() + print(self.indent_string + "For a full list of the word types, please see:") + print() + print( + self.indent_string + self.indent_string + self.printer.colorize(oov_path, "bright") + ) + print() + print(self.indent_string + "For a by-utterance breakdown of missing words, see:") + print() + + print( + self.indent_string + + self.indent_string + + self.printer.colorize(utterance_oov_path, "bright") ) else: - message = ( - "There were no missing words from the dictionary. If you plan on using the a model trained " + print( + f"There were {self.printer.colorize('no', 'yellow')} missing words from the dictionary. If you plan on using the a model trained " "on this dataset to align other datasets in the future, it is recommended that there be at " "least some missing words." ) - return message + print() - def analyze_wav_errors(self) -> str: + def analyze_wav_errors(self) -> None: """ Analyzes any sound file issues in the corpus and constructs message - - Returns - ------- - str - Sound file validation result message """ - output_dir = self.corpus.output_directory - wav_read_errors = self.corpus.sound_file_errors + self._print_sub_header("Sound file read errors") + output_dir = self.output_directory + wav_read_errors = self.sound_file_errors if wav_read_errors: path = os.path.join(output_dir, "sound_file_errors.csv") with open(path, "w") as f: for p in wav_read_errors: f.write(f"{p}\n") - message = ( - f"There were {len(wav_read_errors)} sound files that could not be read. " - f"Please see {path} for a list." + print( + f"There were {self.printer.colorize(len(wav_read_errors), 'red')} issues reading sound files. " + f"Please see {self.printer.colorize(path, 'bright')} for a list." ) else: - message = "There were no sound files that could not be read." 
+ print(f"There were {self.printer.colorize('no', 'green')} issues reading sound files.") + print() - return message - - def analyze_missing_features(self) -> str: + def analyze_missing_features(self) -> None: """ Analyzes issues in feature generation in the corpus and constructs message - - Returns - ------- - str - Feature validation result message """ + self._print_sub_header("Feature generation") if self.ignore_acoustics: - return "Acoustic feature generation was skipped." - output_dir = self.corpus.output_directory - missing_features = [x for x in self.corpus.utterances.values() if x.ignored] + print("Acoustic feature generation was skipped.") + output_dir = self.output_directory + missing_features = [x for x in self.utterances.values() if x.ignored] if missing_features: path = os.path.join(output_dir, "missing_features.csv") with open(path, "w") as f: @@ -280,159 +590,205 @@ def analyze_missing_features(self) -> str: else: f.write(f"{utt.file.wav_path}\n") - message = ( - f"There were {len(missing_features)} utterances missing features. " - f"Please see {path} for a list." + print( + f"There were {self.printer.colorize(len(missing_features), 'red')} utterances missing features. " + f"Please see {self.printer.colorize(path, 'bright')} for a list." ) else: - message = "There were no utterances missing features." - return message + print( + f"There were {self.printer.colorize('no', 'green')} utterances missing features." 
+ ) + print() - def analyze_files_with_no_transcription(self) -> str: + def analyze_files_with_no_transcription(self) -> None: """ Analyzes issues with sound files that have no transcription files in the corpus and constructs message - - Returns - ------- - str - File matching validation result message """ - output_dir = self.corpus.output_directory - if self.corpus.no_transcription_files: + self._print_sub_header("Files without transcriptions") + output_dir = self.output_directory + if self.no_transcription_files: path = os.path.join(output_dir, "missing_transcriptions.csv") with open(path, "w") as f: - for file_path in self.corpus.no_transcription_files: + for file_path in self.no_transcription_files: f.write(f"{file_path}\n") - message = ( - f"There were {len(self.corpus.no_transcription_files)} sound files missing transcriptions. " - f"Please see {path} for a list." + print( + f"There were {self.printer.colorize(len(self.no_transcription_files), 'red')} sound files missing transcriptions. " + f"Please see {self.printer.colorize(path, 'bright')} for a list." ) else: - message = "There were no sound files missing transcriptions." - return message + print( + f"There were {self.printer.colorize('no', 'green')} sound files missing transcriptions." 
+ ) + print() - def analyze_transcriptions_with_no_wavs(self) -> str: + def analyze_transcriptions_with_no_wavs(self) -> None: """ Analyzes issues with transcription that have no sound files in the corpus and constructs message - - Returns - ------- - str - File matching validation result message """ - output_dir = self.corpus.output_directory - if self.corpus.transcriptions_without_wavs: + self._print_sub_header("Transcriptions without sound files") + output_dir = self.output_directory + if self.transcriptions_without_wavs: path = os.path.join(output_dir, "transcriptions_missing_sound_files.csv") with open(path, "w") as f: - for file_path in self.corpus.transcriptions_without_wavs: + for file_path in self.transcriptions_without_wavs: f.write(f"{file_path}\n") - message = ( - f"There were {len(self.corpus.transcriptions_without_wavs)} transcription files missing sound files. " - f"Please see {path} for a list." + print( + f"There were {self.printer.colorize(len(self.transcriptions_without_wavs), 'red')} transcription files missing sound files. " + f"Please see {self.printer.colorize(path, 'bright')} for a list." ) else: - message = "There were no transcription files missing sound files." - return message + print( + f"There were {self.printer.colorize('no', 'green')} transcription files missing sound files." 
+ ) + print() - def analyze_textgrid_read_errors(self) -> str: + def analyze_textgrid_read_errors(self) -> None: """ Analyzes issues with reading TextGrid files in the corpus and constructs message - - Returns - ------- - str - TextGrid validation result message """ - output_dir = self.corpus.output_directory - if self.corpus.textgrid_read_errors: + self._print_sub_header("TextGrid read errors") + output_dir = self.output_directory + if self.textgrid_read_errors: path = os.path.join(output_dir, "textgrid_read_errors.txt") with open(path, "w") as f: - for k, v in self.corpus.textgrid_read_errors.items(): + for k, v in self.textgrid_read_errors.items(): f.write( f"The TextGrid file {k} gave the following error on load:\n\n{v}\n\n\n" ) - message = ( - f"There were {len(self.corpus.textgrid_read_errors)} TextGrid files that could not be parsed. " - f"Please see {path} for a list." + print( + f"There were {self.printer.colorize(len(self.textgrid_read_errors), 'red')} TextGrid files that could not be loaded. " + f"Please see {self.printer.colorize(path, 'bright')} for a list." ) else: - message = "There were no issues reading TextGrids." 
- return message + print(f"There were {self.printer.colorize('no', 'green')} issues reading TextGrids.") + print() - def analyze_unreadable_text_files(self) -> str: + def analyze_unreadable_text_files(self) -> None: """ Analyzes issues with reading text files in the corpus and constructs message - - Returns - ------- - str - Text file validation result message """ - output_dir = self.corpus.output_directory - if self.corpus.decode_error_files: + self._print_sub_header("Text file read errors") + output_dir = self.output_directory + if self.decode_error_files: path = os.path.join(output_dir, "utf8_read_errors.csv") with open(path, "w") as f: - for file_path in self.corpus.decode_error_files: + for file_path in self.decode_error_files: f.write(f"{file_path}\n") - message = ( - f"There were {len(self.corpus.decode_error_files)} text files that could not be parsed. " - f"Please see {path} for a list." + print( + f"There were {self.printer.colorize(len(self.decode_error_files), 'red')} text files that could not be read. " + f"Please see {self.printer.colorize(path, 'bright')} for a list." ) else: - message = "There were no issues reading text files." - return message + print(f"There were {self.printer.colorize('no', 'green')} issues reading text files.") + print() - def analyze_unaligned_utterances(self) -> None: + def compile_information(self) -> None: """ - Analyzes issues with any unaligned files following training + Compiles information about alignment, namely what the overall log-likelihood was + and how many files were unaligned. 
+ + See Also + -------- + :func:`~montreal_forced_aligner.alignment.multiprocessing.compile_information_func` + Multiprocessing helper function for each job + :meth:`.AlignMixin.compile_information_arguments` + Job method for generating arguments for the helper function """ - unaligned_utts = self.trainer.get_unaligned_utterances() - num_utterances = self.corpus.num_utterances - if unaligned_utts: - path = os.path.join(self.corpus.output_directory, "unalignable_files.csv") - with open(path, "w") as f: - f.write("File path,begin,end,duration,text length\n") - for utt in unaligned_utts: - utterance = self.corpus.utterances[utt] - utt_duration = utterance.duration - utt_length_words = utterance.text.count(" ") + 1 - if utterance.begin is not None: - f.write( - f"{utterance.file.wav_path},{utterance.begin},{utterance.end},{utt_duration},{utt_length_words}\n" - ) - else: - f.write(f"{utterance.file.wav_path},,,{utt_duration},{utt_length_words}\n") - message = ( - f"There were {len(unaligned_utts)} unalignable utterances out of {num_utterances} after the initial training. " - f"Please see {path} for a list." + self.logger.debug("Analyzing alignment information") + compile_info_begin = time.time() + + jobs = self.compile_information_arguments() + + if self.use_mp: + alignment_info = run_mp( + compile_information_func, jobs, self.working_log_directory, True ) else: - message = f"All {num_utterances} utterances were successfully aligned!" 
- print(self.alignment_analysis_template.format(message)) + alignment_info = run_non_mp( + compile_information_func, jobs, self.working_log_directory, True + ) - def validate(self): + avg_like_sum = 0 + avg_like_frames = 0 + average_logdet_sum = 0 + average_logdet_frames = 0 + beam_too_narrow_count = 0 + too_short_count = 0 + unaligned_utts = [] + for data in alignment_info.values(): + unaligned_utts.extend(data["unaligned"]) + beam_too_narrow_count += len(data["unaligned"]) + too_short_count += len(data["too_short"]) + avg_like_frames += data["total_frames"] + avg_like_sum += data["log_like"] * data["total_frames"] + if "logdet_frames" in data: + average_logdet_frames += data["logdet_frames"] + average_logdet_sum += data["logdet"] * data["logdet_frames"] + + if not avg_like_frames: + self.logger.debug( + "No utterances were aligned, this likely indicates serious problems with the aligner." + ) + self._print_red_stat(0, f"of {len(self.utterances)} utterances were aligned") + else: + if too_short_count: + self._print_red_stat(too_short_count, "utterances were too short to be aligned") + else: + self._print_green_stat(0, "utterances were too short to be aligned") + if beam_too_narrow_count: + self.logger.debug( + f"There were {beam_too_narrow_count} utterances that could not be aligned with " + f"the current beam settings." 
+ ) + self._print_yellow_stat( + beam_too_narrow_count, "utterances that need a larger beam to align" + ) + else: + self._print_green_stat(0, "utterances that need a larger beam to align") + + num_utterances = self.num_utterances + if unaligned_utts: + path = os.path.join(self.output_directory, "unalignable_files.csv") + with open(path, "w") as f: + f.write("File path,begin,end,duration,text length\n") + for utt in unaligned_utts: + utterance = self.utterances[utt] + utt_duration = utterance.duration + utt_length_words = utterance.text.count(" ") + 1 + if utterance.begin is not None: + f.write( + f"{utterance.file.wav_path},{utterance.begin},{utterance.end},{utt_duration},{utt_length_words}\n" + ) + else: + f.write( + f"{utterance.file.wav_path},,,{utt_duration},{utt_length_words}\n" + ) + print( + f"There were {self.printer.colorize(len(unaligned_utts), 'red')} unaligned utterances out of {self.printer.colorize(self.num_utterances, 'bright')} after initial training. " + f"Please see {self.printer.colorize(path, 'bright')} for a list." 
+ ) + + self._print_green_stat( + num_utterances - beam_too_narrow_count - too_short_count, + "utterances were successfully aligned", + ) + average_log_like = avg_like_sum / avg_like_frames + if average_logdet_sum: + average_log_like += average_logdet_sum / average_logdet_frames + self.logger.debug(f"Average per frame likelihood for alignment: {average_log_like}") + self.logger.debug(f"Compiling information took {time.time() - compile_info_begin}") + + def initialize_utt_fsts(self) -> None: """ - Performs validation of the corpus + Construct utterance FSTs """ - self.analyze_setup() - if self.ignore_acoustics: - print("Skipping test alignments.") - return - if not isinstance(self.trainer, PretrainedAligner): - self.trainer.init_training( - self.trainer.train_type, self.temp_directory, self.corpus, self.dictionary, None - ) - self.trainer.train() - self.trainer.align(None) - self.analyze_unaligned_utterances() - if self.test_transcriptions: - self.test_utterance_transcriptions() + self.output_utt_fsts() - def test_utterance_transcriptions(self): + def test_utterance_transcriptions(self) -> None: """ Tests utterance transcriptions with simple unigram models based on the utterance text and frequent words in the corpus @@ -444,61 +800,306 @@ def test_utterance_transcriptions(self): """ self.logger.info("Checking utterance transcriptions...") - model_directory = self.trainer.align_directory - log_directory = os.path.join(model_directory, "log") - try: - jobs = [x.compile_utterance_train_graphs_arguments(self) for x in self.corpus.jobs] - if self.trainer.feature_config.use_mp: - run_mp(compile_utterance_train_graphs_func, jobs, log_directory) + jobs = self.compile_utterance_train_graphs_arguments() + if self.use_mp: + run_mp(compile_utterance_train_graphs_func, jobs, self.working_log_directory) else: - run_non_mp(compile_utterance_train_graphs_func, jobs, log_directory) + run_non_mp(compile_utterance_train_graphs_func, jobs, self.working_log_directory) 
self.logger.info("Utterance FSTs compiled!") self.logger.info("Decoding utterances (this will take some time)...") - jobs = [x.test_utterances_arguments(self) for x in self.corpus.jobs] - if self.trainer.feature_config.use_mp: - run_mp(test_utterances_func, jobs, log_directory) + jobs = self.test_utterances_arguments() + if self.use_mp: + run_mp(test_utterances_func, jobs, self.working_log_directory) else: - run_non_mp(test_utterances_func, jobs, log_directory) + run_non_mp(test_utterances_func, jobs, self.working_log_directory) self.logger.info("Finished decoding utterances!") errors = {} for job in jobs: for dict_name in job.dictionaries: - word_mapping = self.dictionary.dictionary_mapping[ - dict_name - ].reversed_word_mapping + word_mapping = self.dictionary_mapping[dict_name].reversed_word_mapping aligned_int = load_scp(job.out_int_paths[dict_name]) for utt, line in sorted(aligned_int.items()): text = [] for t in line: text.append(word_mapping[int(t)]) - ref_text = self.corpus.utterances[utt].text.split() + ref_text = self.utterances[utt].text.split() edits = edit_distance(text, ref_text) if edits: errors[utt] = (edits, ref_text, text) if not errors: - message = "There were no utterances with transcription issues." + + print( + f"There were {self.printer.colorize('no', 'green')} utterances with transcription issues." + ) else: - out_path = os.path.join(self.corpus.output_directory, "transcription_problems.csv") + out_path = os.path.join(self.output_directory, "transcription_problems.csv") with open(out_path, "w") as problemf: problemf.write("Utterance,Edits,Reference,Decoded\n") for utt, (edits, ref_text, text) in sorted( errors.items(), key=lambda x: -1 * (len(x[1][1]) + len(x[1][2])) ): problemf.write(f"{utt},{edits},{' '.join(ref_text)},{' '.join(text)}\n") - message = ( - f"There were {len(errors)} of {self.corpus.num_utterances} utterances with at least one transcription issue. " - f"Please see the outputted csv file {out_path}." 
+ print( + f"There were {self.printer.colorize(len(errors), 'red')} of {self.printer.colorize(self.num_utterances, 'bright')} utterances with at least one transcription issue. " + f"Please see {self.printer.colorize(out_path, 'bright')} for a list." ) - self.logger.info(self.transcription_analysis_template.format(message)) + except Exception as e: + if isinstance(e, KaldiProcessingError): + log_kaldi_errors(e.error_logs, self.logger) + e.update_log_file(self.logger) + raise + + +class TrainingValidator(TrainableAligner, ValidationMixin): + """ + Validator class for checking whether a corpus and a dictionary will work together + for training + + See Also + -------- + :class:`~montreal_forced_aligner.acoustic_modeling.trainer.TrainableAligner` + For training configuration + :class:`~montreal_forced_aligner.validator.ValidationMixin` + For validation parameters + + Attributes + ---------- + training_configs: dict[str, :class:`~montreal_forced_aligner.acoustic_modeling.monophone.MonophoneTrainer`] + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.training_configs = {} + self.add_config("monophone", {}) + + @classmethod + def parse_parameters( + cls, + config_path: Optional[str] = None, + args: Optional[Namespace] = None, + unknown_args: Optional[list[str]] = None, + ) -> MetaDict: + + """ + Parse parameters for validation from a config path or command-line arguments + + Parameters + ---------- + config_path: str + Config path + args: :class:`~argparse.Namespace` + Command-line arguments from argparse + unknown_args: list[str], optional + Extra command-line arguments + + Returns + ------- + dict[str, Any] + Configuration parameters + """ + global_params = {} + training_params = [] + if config_path: + with open(config_path, "r", encoding="utf8") as f: + data = yaml.load(f, Loader=yaml.SafeLoader) + training_params = [] + for k, v in data.items(): + if k == "training": + for t in v: + for k2, v2 in t.items(): + if "features" in v2: + 
global_params.update(v2["features"]) + del v2["features"] + training_params.append((k2, v2)) + elif k == "features": + if "type" in v: + v["feature_type"] = v["type"] + del v["type"] + global_params.update(v) + else: + global_params[k] = v + if not training_params: + raise ConfigError(f"No 'training' block found in {config_path}") + else: # default training configuration + training_params.append(("monophone", {})) + if training_params: + if training_params[0][0] != "monophone": + raise ConfigError("The first round of training must be monophone.") + global_params["training_configuration"] = training_params + global_params.update(cls.parse_args(args, unknown_args)) + return global_params + + def setup(self): + """ + Set up the corpus and validator + + Raises + ------ + :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` + If there were any errors in running Kaldi binaries + """ + if self.initialized: + return + try: + self.dictionary_setup() + self._load_corpus() + self.set_lexicon_word_set(self.corpus_word_set) + self.write_lexicon_information() + + for speaker in self.speakers.values(): + speaker.set_dictionary(self.get_dictionary(speaker.name)) + self.initialize_jobs() + self.write_corpus_information() + self.create_corpus_split() + if self.test_transcriptions: + self.write_lexicon_information(write_disambiguation=True) + if self.ignore_acoustics: + self.logger.info("Skipping acoustic feature generation") + else: + self.generate_features() + + if self.test_transcriptions: + self.initialize_utt_fsts() + self.initialized = True + except Exception as e: + if isinstance(e, KaldiProcessingError): + log_kaldi_errors(e.error_logs, self.logger) + e.update_log_file(self.logger) + raise + + def validate(self): + """ + Performs validation of the corpus + """ + self.setup() + self.analyze_setup() + if self.ignore_acoustics: + print("Skipping test alignments.") + return + self._print_header("Training") + self.train(True) + if self.test_transcriptions: + 
self._print_header("Test transcriptions") + self.test_utterance_transcriptions() + + +class PretrainedValidator(PretrainedAligner, ValidationMixin): + """ + Validator class for checking whether a corpus, a dictionary, and + an acoustic model will work together for alignment + + See Also + -------- + :class:`~montreal_forced_aligner.alignment.pretrained.PretrainedAligner` + For alignment configuration + :class:`~montreal_forced_aligner.validator.ValidationMixin` + For validation parameters + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def setup(self): + """ + Set up the corpus and validator + + Raises + ------ + :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` + If there were any errors in running Kaldi binaries + """ + if self.initialized: + return + try: + self.dictionary_setup() + self._load_corpus() + self.set_lexicon_word_set(self.corpus_word_set) + self.write_lexicon_information() + + for speaker in self.speakers.values(): + speaker.set_dictionary(self.get_dictionary(speaker.name)) + self.initialize_jobs() + self.write_corpus_information() + self.create_corpus_split() + if self.test_transcriptions: + self.write_lexicon_information(write_disambiguation=True) + if self.ignore_acoustics: + self.logger.info("Skipping acoustic feature generation") + else: + self.generate_features() + self.acoustic_model.validate(self) + self.acoustic_model.export_model(self.working_directory) + self.acoustic_model.log_details(self.logger) + if self.test_transcriptions: + self.write_lexicon_information(write_disambiguation=True) + if self.ignore_acoustics: + self.logger.info("Skipping acoustic feature generation") + else: + self.generate_features() + if self.test_transcriptions: + self.initialize_utt_fsts() + self.initialized = True except Exception as e: if isinstance(e, KaldiProcessingError): log_kaldi_errors(e.error_logs, self.logger) - e.update_log_file(self.logger.handlers[0].baseFilename) + e.update_log_file(self.logger) raise + + 
def align(self) -> None: + """ + Validate alignment + + """ + done_path = os.path.join(self.working_directory, "done") + dirty_path = os.path.join(self.working_directory, "dirty") + if os.path.exists(done_path): + self.logger.debug("Alignment already done, skipping.") + return + try: + log_dir = os.path.join(self.working_directory, "log") + os.makedirs(log_dir, exist_ok=True) + self.compile_train_graphs() + + self.logger.debug("Performing first-pass alignment...") + self.speaker_independent = True + self.align_utterances() + if self.uses_speaker_adaptation: + self.logger.debug("Calculating fMLLR for speaker adaptation...") + self.calc_fmllr() + + self.speaker_independent = False + self.logger.debug("Performing second-pass alignment...") + self.align_utterances() + + except Exception as e: + with open(dirty_path, "w"): + pass + if isinstance(e, KaldiProcessingError): + log_kaldi_errors(e.error_logs, self.logger) + e.update_log_file(self.logger) + raise + with open(done_path, "w"): + pass + + def validate(self) -> None: + """ + Performs validation of the corpus + """ + self.setup() + self.analyze_setup() + if self.ignore_acoustics: + print("Skipping test alignments.") + return + self._print_header("Alignment") + self.align() + self.compile_information() + if self.test_transcriptions: + self._print_header("Test transcriptions") + self.test_utterance_transcriptions() diff --git a/pyproject.toml b/pyproject.toml index 9babfc49..624dd22a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,5 +78,8 @@ exclude_lines = [ "def history_save_handler() -> None:", "class ExitHooks(object):", "def main() -> None:", + "if os.path.exists", + "@abstractmethod", + 'if "MFA_ERROR"', ] fail_under = 50 diff --git a/rtd_environment.yml b/rtd_environment.yml index 611790d0..9803d6dc 100644 --- a/rtd_environment.yml +++ b/rtd_environment.yml @@ -2,16 +2,16 @@ name: mfa channels: - conda-forge dependencies: - - python>=3.8 # or 2.7 if you are feeling nostalgic + - python>=3.9 - numpy - 
librosa - tqdm - requests - colorama - pyyaml + - praatio - pip - pip: - - praatio >= 5.0 - sphinxemoji - sphinxcontrib-autoprogram - git+https://github.com/pydata/pydata-sphinx-theme.git diff --git a/tests/conftest.py b/tests/conftest.py index f7b87b6b..f89ed444 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,25 +1,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING - -# from montreal_forced_aligner.command_line.mfa import fix_path - -# fix_path() - - -if TYPE_CHECKING: - from montreal_forced_aligner.config import FeatureConfig - import os import shutil import pytest import yaml -from montreal_forced_aligner.config import align_yaml_to_config, train_yaml_to_config -from montreal_forced_aligner.corpus import Corpus -from montreal_forced_aligner.dictionary import MultispeakerDictionary - @pytest.fixture(scope="session") def test_dir(): @@ -68,14 +54,8 @@ def generated_dir(test_dir): @pytest.fixture(scope="session") def temp_dir(generated_dir): - return os.path.join(generated_dir, "temp") - -@pytest.fixture(scope="session") -def config_dir(generated_dir): - path = os.path.join(generated_dir, "configs") - os.makedirs(path, exist_ok=True) - return path + return os.path.join(generated_dir, "temp") @pytest.fixture(scope="session") @@ -88,18 +68,10 @@ def english_acoustic_model(): @pytest.fixture(scope="session") def english_dictionary(): - from montreal_forced_aligner.command_line.model import download_model, get_pretrained_path + from montreal_forced_aligner.command_line.model import download_model download_model("dictionary", "english") - return get_pretrained_path("dictionary", "english") - - -@pytest.fixture(scope="session") -def basic_dictionary_config(): - from montreal_forced_aligner.config.dictionary_config import DictionaryConfig - - config = DictionaryConfig(debug=True) - return config + return "english" @pytest.fixture(scope="session") @@ -112,10 +84,10 @@ def english_ipa_acoustic_model(): @pytest.fixture(scope="session") 
def english_us_ipa_dictionary(): - from montreal_forced_aligner.command_line.model import download_model, get_pretrained_path + from montreal_forced_aligner.command_line.model import download_model download_model("dictionary", "english_us_ipa") - return get_pretrained_path("dictionary", "english_us_ipa") + return "english_us_ipa" @pytest.fixture(scope="session") @@ -157,7 +129,7 @@ def transcription_language_model_arpa(language_model_dir, generated_dir): @pytest.fixture(scope="session") def corpus_root_dir(generated_dir): - return os.path.join(generated_dir, "corpus") + return os.path.join(generated_dir, "constructed_test_corpora") @pytest.fixture(scope="session") @@ -170,15 +142,6 @@ def mono_align_model_path(output_model_dir): return os.path.join(output_model_dir, "mono_model.zip") -@pytest.fixture(scope="session") -def default_feature_config() -> FeatureConfig: - from montreal_forced_aligner.config import FeatureConfig - - fc = FeatureConfig() - fc.use_mp = False - return fc - - @pytest.fixture(scope="session") def basic_corpus_dir(corpus_root_dir, wav_dir, lab_dir): path = os.path.join(corpus_root_dir, "basic") @@ -433,15 +396,6 @@ def flac_tg_corpus_dir(corpus_root_dir, wav_dir, textgrid_dir): return path -@pytest.fixture(scope="session") -def flac_transcribe_corpus_dir(corpus_root_dir, wav_dir): - path = os.path.join(corpus_root_dir, "flac_transcribe_corpus") - os.makedirs(path, exist_ok=True) - name = "61-70968-0000" - shutil.copyfile(os.path.join(wav_dir, name + ".flac"), os.path.join(path, name + ".flac")) - return path - - @pytest.fixture(scope="session") def shortsegments_corpus_dir(corpus_root_dir, wav_dir, textgrid_dir): path = os.path.join(corpus_root_dir, "short_segments") @@ -454,18 +408,6 @@ def shortsegments_corpus_dir(corpus_root_dir, wav_dir, textgrid_dir): return path -@pytest.fixture(scope="session") -def vietnamese_corpus_dir(corpus_root_dir, wav_dir, textgrid_dir): - path = os.path.join(corpus_root_dir, "vietnamese") - os.makedirs(path, 
exist_ok=True) - name = "vietnamese" - shutil.copyfile(os.path.join(wav_dir, "dummy.wav"), os.path.join(path, name + ".wav")) - shutil.copyfile( - os.path.join(textgrid_dir, name + ".TextGrid"), os.path.join(path, name + ".TextGrid") - ) - return path - - @pytest.fixture(scope="session") def dict_dir(test_dir): return os.path.join(test_dir, "dictionaries") @@ -491,66 +433,6 @@ def xsampa_dict_path(dict_dir): return os.path.join(dict_dir, "xsampa.txt") -@pytest.fixture(scope="session") -def expected_dict_path(dict_dir): - return os.path.join(dict_dir, "expected") - - -@pytest.fixture(scope="session") -def basic_topo_path(expected_dict_path): - return os.path.join(expected_dict_path, "topo") - - -@pytest.fixture(scope="session") -def basic_graphemes_path(expected_dict_path): - return os.path.join(expected_dict_path, "graphemes.txt") - - -@pytest.fixture(scope="session") -def basic_phone_map_path(expected_dict_path): - return os.path.join(expected_dict_path, "phone_map.txt") - - -@pytest.fixture(scope="session") -def basic_phones_path(expected_dict_path): - return os.path.join(expected_dict_path, "phones.txt") - - -@pytest.fixture(scope="session") -def basic_words_path(expected_dict_path): - return os.path.join(expected_dict_path, "words.txt") - - -@pytest.fixture(scope="session") -def basic_rootsint_path(expected_dict_path): - return os.path.join(expected_dict_path, "roots.int") - - -@pytest.fixture(scope="session") -def basic_rootstxt_path(expected_dict_path): - return os.path.join(expected_dict_path, "roots.txt") - - -@pytest.fixture(scope="session") -def basic_setsint_path(expected_dict_path): - return os.path.join(expected_dict_path, "sets.int") - - -@pytest.fixture(scope="session") -def basic_setstxt_path(expected_dict_path): - return os.path.join(expected_dict_path, "sets.txt") - - -@pytest.fixture(scope="session") -def basic_word_boundaryint_path(expected_dict_path): - return os.path.join(expected_dict_path, "word_boundary.int") - - 
-@pytest.fixture(scope="session") -def basic_word_boundarytxt_path(expected_dict_path): - return os.path.join(expected_dict_path, "word_boundary.txt") - - @pytest.fixture(scope="session") def sick_dict_path(dict_dir): return os.path.join(dict_dir, "sick.txt") @@ -571,49 +453,13 @@ def speaker_dictionary_path(sick_dict_path, acoustic_dict_path, generated_dir): @pytest.fixture(scope="session") -def acoustic_corpus_wav_path(basic_dir): - return os.path.join(basic_dir, "acoustic_corpus.wav") - - -@pytest.fixture(scope="session") -def acoustic_corpus_lab_path(basic_dir): - return os.path.join(basic_dir, "acoustic_corpus.lab") - - -@pytest.fixture(scope="session") -def michael_corpus_lab_path(basic_dir): - return os.path.join(basic_dir, "michael_corpus.lab") - - -@pytest.fixture(scope="session") -def output_directory(basic_dir): - return os.path.join(basic_dir, "output") - - -@pytest.fixture(scope="session") -def acoustic_corpus_textgrid_path(basic_dir): - return os.path.join(basic_dir, "acoustic_corpus.TextGrid") - - -@pytest.fixture(scope="session") -def sick_dict(sick_dict_path, generated_dir, basic_dictionary_config): - output_directory = os.path.join(generated_dir, "sickcorpus") - - dictionary = MultispeakerDictionary(sick_dict_path, output_directory, basic_dictionary_config) - dictionary.write() - return dictionary - - -@pytest.fixture(scope="session") -def sick_corpus(basic_corpus_dir, generated_dir, basic_dictionary_config): - output_directory = os.path.join(generated_dir, "sickcorpus") - corpus = Corpus(basic_corpus_dir, output_directory, basic_dictionary_config, num_jobs=2) - return corpus +def sick_dict(sick_dict_path, generated_dir): + return sick_dict_path @pytest.fixture(scope="session") -def textgrid_directory(test_dir): - return os.path.join(test_dir, "textgrid") +def sick_corpus(basic_corpus_dir): + return basic_corpus_dir @pytest.fixture(scope="session") @@ -631,20 +477,6 @@ def ivector_output_model_path(generated_dir): return 
os.path.join(generated_dir, "ivector_output_model.zip") -@pytest.fixture(scope="session") -def training_dict_path(test_dir): - return os.path.join( - test_dir, - "dictionaries", - "chinese_dict.txt", - ) - - -@pytest.fixture(scope="session") -def g2p_model_path(generated_dir): - return os.path.join(generated_dir, "pinyin_g2p.zip") - - @pytest.fixture(scope="session") def sick_g2p_model_path(generated_dir): return os.path.join(generated_dir, "sick_g2p.zip") @@ -660,113 +492,83 @@ def orth_sick_output(generated_dir): return os.path.join(generated_dir, "orth_sick.txt") -@pytest.fixture(scope="session") -def example_output_model_path(generated_dir): - return os.path.join(generated_dir, "example_output_model.zip") - - -@pytest.fixture(scope="session") -def KO_dict(test_dir): - return os.path.join(test_dir, "dictionaries", "KO_dict.txt") - - @pytest.fixture(scope="session") def config_directory(test_dir): return os.path.join(test_dir, "configs") @pytest.fixture(scope="session") -def basic_train_config(config_directory): +def basic_train_config_path(config_directory): return os.path.join(config_directory, "basic_train_config.yaml") @pytest.fixture(scope="session") -def transcribe_config(config_directory): +def transcribe_config_path(config_directory): return os.path.join(config_directory, "transcribe.yaml") @pytest.fixture(scope="session") -def g2p_config(config_directory): +def g2p_config_path(config_directory): return os.path.join(config_directory, "g2p_config.yaml") @pytest.fixture(scope="session") -def train_g2p_config(config_directory): +def train_g2p_config_path(config_directory): return os.path.join(config_directory, "train_g2p_config.yaml") @pytest.fixture(scope="session") -def basic_train_lm_config(config_directory): +def basic_train_lm_config_path(config_directory): return os.path.join(config_directory, "basic_train_lm.yaml") @pytest.fixture(scope="session") -def different_punctuation_config(config_directory): +def 
different_punctuation_config_path(config_directory): return os.path.join(config_directory, "different_punctuation_config.yaml") @pytest.fixture(scope="session") -def basic_align_config(config_directory): +def basic_align_config_path(config_directory): return os.path.join(config_directory, "basic_align_config.yaml") @pytest.fixture(scope="session") -def basic_segment_config(config_directory): +def basic_segment_config_path(config_directory): return os.path.join(config_directory, "basic_segment_config.yaml") @pytest.fixture(scope="session") -def train_ivector_config(config_directory): +def train_ivector_config_path(config_directory): return os.path.join(config_directory, "ivector_train.yaml") -@pytest.fixture(scope="session") -def mono_train_config_path(config_directory): - return os.path.join(config_directory, "mono_train.yaml") - - -@pytest.fixture(scope="session") -def mono_train_config(mono_train_config_path): - return train_yaml_to_config(mono_train_config_path) - - @pytest.fixture(scope="session") def mono_align_config_path(config_directory): return os.path.join(config_directory, "mono_align.yaml") @pytest.fixture(scope="session") -def mono_align_config(mono_align_config_path): - return align_yaml_to_config(mono_align_config_path)[0] - - -@pytest.fixture(scope="session") -def tri_train_config(config_directory): - return train_yaml_to_config(os.path.join(config_directory, "tri_train.yaml")) - - -@pytest.fixture(scope="session") -def lda_train_config(config_directory): - return train_yaml_to_config(os.path.join(config_directory, "lda_train.yaml")) +def mono_train_config_path(config_directory): + return os.path.join(config_directory, "mono_train.yaml") @pytest.fixture(scope="session") -def sat_train_config(config_directory): - return train_yaml_to_config(os.path.join(config_directory, "sat_train.yaml")) +def tri_train_config_path(config_directory): + return os.path.join(config_directory, "tri_train.yaml") @pytest.fixture(scope="session") -def 
lda_sat_train_config(config_directory): - return train_yaml_to_config(os.path.join(config_directory, "lda_sat_train.yaml")) +def lda_train_config_path(config_directory): + return os.path.join(config_directory, "lda_train.yaml") @pytest.fixture(scope="session") -def ivector_train_config(config_directory): - return train_yaml_to_config(os.path.join(config_directory, "ivector_train.yaml")) +def sat_train_config_path(config_directory): + return os.path.join(config_directory, "sat_train.yaml") @pytest.fixture(scope="session") -def multispeaker_dictionary_config(generated_dir, sick_dict_path): +def multispeaker_dictionary_config_path(generated_dir, sick_dict_path): path = os.path.join(generated_dir, "multispeaker_dictionary.yaml") with open(path, "w", encoding="utf8") as f: yaml.safe_dump({"default": "english", "michael": sick_dict_path}, f) @@ -781,3 +583,8 @@ def ipa_speaker_dict_path(generated_dir, english_uk_ipa_dictionary, english_us_i {"default": english_us_ipa_dictionary, "speaker": english_uk_ipa_dictionary}, f ) return path + + +@pytest.fixture(scope="session") +def test_align_config(): + return {"beam": 100, "retry_beam": 400} diff --git a/tests/data/configs/basic_train_lm.yaml b/tests/data/configs/basic_train_lm.yaml index 9254584b..bdfc192a 100644 --- a/tests/data/configs/basic_train_lm.yaml +++ b/tests/data/configs/basic_train_lm.yaml @@ -1,5 +1,4 @@ order: 3 method: kneser_ney -prune: true prune_thresh_small: 0.0000003 prune_thresh_medium: 0.0000001 diff --git a/tests/data/configs/ivector_train.yaml b/tests/data/configs/ivector_train.yaml index b4747f2a..8880db5d 100644 --- a/tests/data/configs/ivector_train.yaml +++ b/tests/data/configs/ivector_train.yaml @@ -6,7 +6,9 @@ features: frame_shift: 10 training: + - dubm: + num_iterations_init: 4 + num_iterations: 2 - ivector: num_iterations: 2 gaussian_min_count: 2 - ubm_num_iterations_init: 4 diff --git a/tests/data/configs/lda_train.yaml b/tests/data/configs/lda_train.yaml index af01ba6c..ccf928b6 100644 --- 
a/tests/data/configs/lda_train.yaml +++ b/tests/data/configs/lda_train.yaml @@ -15,7 +15,7 @@ training: subset: 1000 - lda: - num_iterations: 2 + num_iterations: 15 num_leaves: 500 max_gaussians: 4000 subset: 1000 diff --git a/tests/data/configs/mono_train.yaml b/tests/data/configs/mono_train.yaml index 908e77c3..1702d153 100644 --- a/tests/data/configs/mono_train.yaml +++ b/tests/data/configs/mono_train.yaml @@ -1,4 +1,4 @@ -beam: 100 +beam: 10 retry_beam: 400 use_mp: false @@ -10,6 +10,6 @@ features: training: - monophone: - num_iterations: 3 + num_iterations: 10 max_gaussians: 500 subset: 1000 diff --git a/tests/data/configs/sat_train.yaml b/tests/data/configs/sat_train.yaml index d7fddf41..b00d8295 100644 --- a/tests/data/configs/sat_train.yaml +++ b/tests/data/configs/sat_train.yaml @@ -24,7 +24,7 @@ training: power: 0.25 - sat: - num_iterations: 2 + num_iterations: 15 num_leaves: 2000 max_gaussians: 10000 power: 0.2 diff --git a/tests/data/configs/train_g2p_config.yaml b/tests/data/configs/train_g2p_config.yaml index 47336ada..a5203920 100644 --- a/tests/data/configs/train_g2p_config.yaml +++ b/tests/data/configs/train_g2p_config.yaml @@ -9,7 +9,7 @@ seed: 1917 delta: 0.0009765 lr: 1.0 batch_size: 200 -max_iterations: 10 +num_iterations: 10 smoothing_method: "kneser_ney" pruning_method: "relative_entropy" model_size: 1000000 diff --git a/tests/data/dictionaries/sick.txt b/tests/data/dictionaries/sick.txt index 865ab0bb..af563d8a 100755 --- a/tests/data/dictionaries/sick.txt +++ b/tests/data/dictionaries/sick.txt @@ -35,7 +35,7 @@ in ih n intensity ih n t eh n s ih t iy saying s ey ih ng words w er d z -here's h iy r z +here's hh iy r z more m ao r um ah m that dh ae t diff --git a/tests/test_abc.py b/tests/test_abc.py new file mode 100644 index 00000000..5a604c21 --- /dev/null +++ b/tests/test_abc.py @@ -0,0 +1,15 @@ +from montreal_forced_aligner.abc import MfaWorker, TrainerMixin +from montreal_forced_aligner.acoustic_modeling import SatTrainer, 
TrainableAligner +from montreal_forced_aligner.alignment import AlignMixin + + +def test_typing(sick_corpus, sick_dict, temp_dir): + am_trainer = TrainableAligner( + corpus_directory=sick_corpus, dictionary_path=sick_dict, temporary_directory=temp_dir + ) + trainer = SatTrainer(identifier="sat", worker=am_trainer) + assert type(trainer).__name__ == "SatTrainer" + assert isinstance(trainer, TrainerMixin) + assert isinstance(trainer, AlignMixin) + assert isinstance(trainer, MfaWorker) + assert isinstance(am_trainer, MfaWorker) diff --git a/tests/test_acoustic_modeling.py b/tests/test_acoustic_modeling.py new file mode 100644 index 00000000..52bee000 --- /dev/null +++ b/tests/test_acoustic_modeling.py @@ -0,0 +1,101 @@ +import argparse +import os +import shutil + +from montreal_forced_aligner.acoustic_modeling.trainer import TrainableAligner +from montreal_forced_aligner.alignment import PretrainedAligner + + +def test_trainer(sick_dict, sick_corpus, generated_dir): + data_directory = os.path.join(generated_dir, "temp", "train_test") + a = TrainableAligner( + corpus_directory=sick_corpus, + dictionary_path=sick_dict, + temporary_directory=data_directory, + ) + assert a.final_identifier == "sat_1" + assert a.training_configs[a.final_identifier].subset == 0 + assert a.training_configs[a.final_identifier].num_leaves == 4200 + assert a.training_configs[a.final_identifier].max_gaussians == 40000 + + +def test_sick_mono( + sick_dict, + sick_corpus, + generated_dir, + mono_train_config_path, + mono_align_model_path, + mono_output_directory, +): + data_directory = os.path.join(generated_dir, "temp", "mono_train_test") + shutil.rmtree(data_directory, ignore_errors=True) + args = argparse.Namespace(use_mp=True, debug=False, verbose=True) + a = TrainableAligner( + corpus_directory=sick_corpus, + dictionary_path=sick_dict, + temporary_directory=data_directory, + **TrainableAligner.parse_parameters(mono_train_config_path, args=args) + ) + a.train() + 
a.export_model(mono_align_model_path) + + data_directory = os.path.join(generated_dir, "temp", "mono_align_test") + shutil.rmtree(data_directory, ignore_errors=True) + a = PretrainedAligner( + corpus_directory=sick_corpus, + dictionary_path=sick_dict, + acoustic_model_path=mono_align_model_path, + temporary_directory=data_directory, + **PretrainedAligner.parse_parameters(args=args) + ) + a.align() + a.export_files(mono_output_directory) + + +def test_sick_tri(sick_dict, sick_corpus, generated_dir, tri_train_config_path): + data_directory = os.path.join(generated_dir, "temp", "tri_test") + shutil.rmtree(data_directory, ignore_errors=True) + a = TrainableAligner( + corpus_directory=sick_corpus, + dictionary_path=sick_dict, + temporary_directory=data_directory, + debug=True, + verbose=True, + **TrainableAligner.parse_parameters(tri_train_config_path) + ) + a.train() + + +def test_sick_lda(sick_dict, sick_corpus, generated_dir, lda_train_config_path): + data_directory = os.path.join(generated_dir, "temp", "lda_test") + shutil.rmtree(data_directory, ignore_errors=True) + a = TrainableAligner( + corpus_directory=sick_corpus, + dictionary_path=sick_dict, + temporary_directory=data_directory, + debug=True, + verbose=True, + **TrainableAligner.parse_parameters(lda_train_config_path) + ) + a.train() + assert len(a.training_configs[a.final_identifier].realignment_iterations) > 0 + assert len(a.training_configs[a.final_identifier].mllt_iterations) > 1 + + +def test_sick_sat(sick_dict, sick_corpus, generated_dir, sat_train_config_path): + data_directory = os.path.join(generated_dir, "temp", "sat_test") + output_model_path = os.path.join(data_directory, "sat_model.zip") + shutil.rmtree(data_directory, ignore_errors=True) + args = argparse.Namespace(use_mp=True, debug=True, verbose=True) + a = TrainableAligner( + **TrainableAligner.parse_parameters(sat_train_config_path, args=args), + corpus_directory=sick_corpus, + dictionary_path=sick_dict, + temporary_directory=data_directory 
+ ) + a.train() + assert len(a.training_configs[a.final_identifier].realignment_iterations) > 0 + assert len(a.training_configs[a.final_identifier].fmllr_iterations) > 1 + a.export_model(output_model_path) + + assert os.path.exists diff --git a/tests/test_aligner.py b/tests/test_aligner.py deleted file mode 100644 index d45f2c34..00000000 --- a/tests/test_aligner.py +++ /dev/null @@ -1,65 +0,0 @@ -import os -import shutil - -from montreal_forced_aligner.aligner import PretrainedAligner, TrainableAligner -from montreal_forced_aligner.models import AcousticModel - - -def test_sick_mono( - sick_dict, - sick_corpus, - generated_dir, - mono_train_config, - mono_align_model_path, - mono_align_config, - mono_output_directory, -): - mono_train_config, align_config, dictionary_config = mono_train_config - data_directory = os.path.join(generated_dir, "temp", "mono_train_test") - shutil.rmtree(data_directory, ignore_errors=True) - a = TrainableAligner( - sick_corpus, sick_dict, mono_train_config, align_config, temp_directory=data_directory - ) - a.train() - a.save(mono_align_model_path) - - model = AcousticModel(mono_align_model_path) - data_directory = os.path.join(generated_dir, "temp", "mono_align_test") - shutil.rmtree(data_directory, ignore_errors=True) - mono_align_config.debug = True - a = PretrainedAligner( - sick_corpus, sick_dict, model, mono_align_config, temp_directory=data_directory, debug=True - ) - a.align() - a.export_textgrids(mono_output_directory) - - -def test_sick_tri(sick_dict, sick_corpus, generated_dir, tri_train_config): - tri_train_config, align_config, dictionary_config = tri_train_config - data_directory = os.path.join(generated_dir, "temp", "tri_test") - shutil.rmtree(data_directory, ignore_errors=True) - a = TrainableAligner( - sick_corpus, sick_dict, tri_train_config, align_config, temp_directory=data_directory - ) - a.train() - - -def test_sick_lda(sick_dict, sick_corpus, generated_dir, lda_train_config): - lda_train_config, align_config, 
dictionary_config = lda_train_config - data_directory = os.path.join(generated_dir, "temp", "lda_test") - shutil.rmtree(data_directory, ignore_errors=True) - a = TrainableAligner( - sick_corpus, sick_dict, lda_train_config, align_config, temp_directory=data_directory - ) - a.train() - - -def test_sick_sat(sick_dict, sick_corpus, generated_dir, sat_train_config): - sat_train_config, align_config, dictionary_config = sat_train_config - data_directory = os.path.join(generated_dir, "temp", "sat_test") - shutil.rmtree(data_directory, ignore_errors=True) - a = TrainableAligner( - sick_corpus, sick_dict, sat_train_config, align_config, temp_directory=data_directory - ) - a.train() - a.export_textgrids(os.path.join(generated_dir, "sick_output")) diff --git a/tests/test_alignment_pretrained.py b/tests/test_alignment_pretrained.py new file mode 100644 index 00000000..5150ac16 --- /dev/null +++ b/tests/test_alignment_pretrained.py @@ -0,0 +1,24 @@ +import os +import shutil + +from montreal_forced_aligner.alignment import PretrainedAligner + + +def test_align_sick( + english_dictionary, english_acoustic_model, basic_corpus_dir, temp_dir, test_align_config +): + a = PretrainedAligner( + corpus_directory=basic_corpus_dir, + dictionary_path=english_dictionary, + acoustic_model_path=english_acoustic_model, + temporary_directory=temp_dir, + debug=True, + verbose=True, + **test_align_config + ) + a.align() + export_directory = os.path.join(temp_dir, "test_align_export") + shutil.rmtree(export_directory, ignore_errors=True) + os.makedirs(export_directory, exist_ok=True) + a.export_files(export_directory) + assert os.path.exists(os.path.join(export_directory, "michael", "acoustic_corpus.TextGrid")) diff --git a/tests/test_commandline_adapt.py b/tests/test_commandline_adapt.py index 90c22db9..322f7c43 100644 --- a/tests/test_commandline_adapt.py +++ b/tests/test_commandline_adapt.py @@ -10,7 +10,7 @@ def test_adapt_basic( generated_dir, english_dictionary, temp_dir, - 
basic_align_config, + test_align_config, english_acoustic_model, ): adapted_model_path = os.path.join(generated_dir, "basic_adapted.zip") @@ -22,7 +22,6 @@ def test_adapt_basic( adapted_model_path, "-t", temp_dir, - "-q", "--clean", "--debug", ] @@ -37,7 +36,7 @@ def test_adapt_multilingual( ipa_speaker_dict_path, generated_dir, temp_dir, - basic_align_config, + basic_align_config_path, english_acoustic_model, english_ipa_acoustic_model, ): @@ -53,7 +52,7 @@ def test_adapt_multilingual( "-t", temp_dir, "--config_path", - basic_align_config, + basic_align_config_path, "-q", "--clean", "--debug", diff --git a/tests/test_commandline_align.py b/tests/test_commandline_align.py index c0ec166b..585a8901 100644 --- a/tests/test_commandline_align.py +++ b/tests/test_commandline_align.py @@ -3,7 +3,8 @@ import pytest from praatio import textgrid as tgio -from montreal_forced_aligner.command_line.align import load_basic_align, run_align_corpus +from montreal_forced_aligner.alignment.pretrained import PretrainedAligner +from montreal_forced_aligner.command_line.align import run_align_corpus from montreal_forced_aligner.command_line.mfa import parser from montreal_forced_aligner.exceptions import PronunciationAcousticMismatchError @@ -40,15 +41,12 @@ def test_align_arguments( "-q", "--clean", "--debug", - "--disable_sat", + "--uses_speaker_adaptation", + "False", ] args, unknown_args = parser.parse_known_args(command) - print(args, unknown_args) - align_config, dictionary_config = load_basic_align() - assert not align_config.disable_sat - if unknown_args: - align_config.update_from_unknown_args(unknown_args) - assert align_config.disable_sat + params = PretrainedAligner.parse_parameters(args=args, unknown_args=unknown_args) + assert not params["uses_speaker_adaptation"] # @pytest.mark.skip(reason='Optimization') @@ -58,7 +56,7 @@ def test_align_basic( generated_dir, english_dictionary, temp_dir, - basic_align_config, + basic_align_config_path, english_acoustic_model, ): command 
= [ @@ -70,7 +68,7 @@ def test_align_basic( "-t", temp_dir, "--config_path", - basic_align_config, + basic_align_config_path, "-q", "--clean", "--debug", @@ -88,7 +86,7 @@ def test_align_basic( "-t", temp_dir, "--config_path", - basic_align_config, + basic_align_config_path, "-q", "--clean", "--debug", @@ -112,7 +110,7 @@ def test_align_basic( assert os.path.exists(path) mod_times[path] = os.stat(path).st_mtime - align_temp_dir = os.path.join(temp_dir, "basic", "align") + align_temp_dir = os.path.join(temp_dir, "basic_pretrained_aligner", "pretrained_aligner") assert os.path.exists(align_temp_dir) backup_textgrid_dir = os.path.join(align_temp_dir, "textgrids") @@ -127,7 +125,7 @@ def test_align_basic( "-t", temp_dir, "--config_path", - basic_align_config, + basic_align_config_path, "-q", "--debug", "--disable_mp", @@ -150,7 +148,7 @@ def test_align_basic( "-t", temp_dir, "--config_path", - basic_align_config, + basic_align_config_path, "-q", "--disable_textgrid_cleanup", "--clean", @@ -170,7 +168,7 @@ def test_align_multilingual( english_uk_ipa_dictionary, generated_dir, temp_dir, - basic_align_config, + basic_align_config_path, english_acoustic_model, english_ipa_acoustic_model, ): @@ -184,7 +182,7 @@ def test_align_multilingual( "-t", temp_dir, "--config_path", - basic_align_config, + basic_align_config_path, "-q", "--clean", "--debug", @@ -198,7 +196,7 @@ def test_align_multilingual_speaker_dict( ipa_speaker_dict_path, generated_dir, temp_dir, - basic_align_config, + basic_align_config_path, english_ipa_acoustic_model, ): @@ -211,7 +209,7 @@ def test_align_multilingual_speaker_dict( "-t", temp_dir, "--config_path", - basic_align_config, + basic_align_config_path, "-q", "--clean", "--debug", @@ -225,7 +223,7 @@ def test_align_multilingual_tg_speaker_dict( ipa_speaker_dict_path, generated_dir, temp_dir, - basic_align_config, + basic_align_config_path, english_ipa_acoustic_model, ): @@ -238,7 +236,7 @@ def test_align_multilingual_tg_speaker_dict( "-t", temp_dir, 
"--config_path", - basic_align_config, + basic_align_config_path, "-q", "--clean", "--debug", @@ -252,7 +250,7 @@ def test_align_split( english_us_ipa_dictionary, generated_dir, temp_dir, - basic_align_config, + basic_align_config_path, english_acoustic_model, english_ipa_acoustic_model, ): @@ -266,7 +264,7 @@ def test_align_split( "-t", temp_dir, "--config_path", - basic_align_config, + basic_align_config_path, "-q", "--clean", "--debug", @@ -283,7 +281,7 @@ def test_align_stereo( generated_dir, english_dictionary, temp_dir, - basic_align_config, + basic_align_config_path, english_acoustic_model, ): output_dir = os.path.join(generated_dir, "stereo_output") @@ -296,7 +294,7 @@ def test_align_stereo( "-t", temp_dir, "--config_path", - basic_align_config, + basic_align_config_path, "-q", "--clean", "--debug", diff --git a/tests/test_commandline_configure.py b/tests/test_commandline_configure.py index bde41a3d..7626ab9c 100644 --- a/tests/test_commandline_configure.py +++ b/tests/test_commandline_configure.py @@ -2,8 +2,8 @@ from montreal_forced_aligner.command_line.mfa import create_parser from montreal_forced_aligner.config import ( - TEMP_DIR, generate_config_path, + get_temporary_directory, load_global_config, update_global_config, ) @@ -15,7 +15,7 @@ def test_configure( sick_dict_path, generated_dir, english_dictionary, - basic_align_config, + basic_align_config_path, english_acoustic_model, ): path = generate_config_path() @@ -33,7 +33,7 @@ def test_configure( "num_jobs": 3, "blas_num_threads": 1, "use_mp": True, - "temp_directory": TEMP_DIR, + "temporary_directory": get_temporary_directory(), } parser = create_parser() command = [ @@ -63,7 +63,7 @@ def test_configure( "num_jobs": 10, "blas_num_threads": 1, "use_mp": False, - "temp_directory": temp_dir, + "temporary_directory": temp_dir, } command = ["configure", "--never_clean", "--enable_mp", "--never_verbose"] parser = create_parser() @@ -82,7 +82,7 @@ def test_configure( "num_jobs": 10, "blas_num_threads": 
1, "use_mp": True, - "temp_directory": temp_dir, + "temporary_directory": temp_dir, } parser = create_parser() @@ -93,16 +93,16 @@ def test_configure( "english", os.path.join(generated_dir, "basic_output"), "-t", - TEMP_DIR, - "-c", - basic_align_config, + get_temporary_directory(), + "--config_path", + basic_align_config_path, "-q", "--clean", "-d", ] args, unknown = parser.parse_known_args(command) assert args.num_jobs == 10 - assert args.temp_directory == TEMP_DIR + assert args.temporary_directory == get_temporary_directory() assert args.clean assert not args.disable_mp if os.path.exists(path): diff --git a/tests/test_commandline_create_segments.py b/tests/test_commandline_create_segments.py index 25fbbf4e..fcc43688 100644 --- a/tests/test_commandline_create_segments.py +++ b/tests/test_commandline_create_segments.py @@ -8,7 +8,7 @@ def test_create_segments( basic_corpus_dir, generated_dir, temp_dir, - basic_segment_config, + basic_segment_config_path, ): output_path = os.path.join(generated_dir, "segment_output") command = [ @@ -22,7 +22,7 @@ def test_create_segments( "--debug", "-v", "--config_path", - basic_segment_config, + basic_segment_config_path, ] args, unknown = parser.parse_known_args(command) run_create_segments(args) diff --git a/tests/test_commandline_g2p.py b/tests/test_commandline_g2p.py index a54eb8fa..05c3d175 100644 --- a/tests/test_commandline_g2p.py +++ b/tests/test_commandline_g2p.py @@ -5,14 +5,11 @@ from montreal_forced_aligner.command_line.g2p import run_g2p from montreal_forced_aligner.command_line.mfa import parser from montreal_forced_aligner.command_line.train_g2p import run_train_g2p -from montreal_forced_aligner.dictionary import PronunciationDictionary +from montreal_forced_aligner.dictionary.pronunciation import PronunciationDictionary from montreal_forced_aligner.g2p.generator import G2P_DISABLED -from montreal_forced_aligner.models import DictionaryModel -def test_generate_pretrained( - english_g2p_model, basic_corpus_dir, 
temp_dir, generated_dir, basic_dictionary_config -): +def test_generate_pretrained(english_g2p_model, basic_corpus_dir, temp_dir, generated_dir): if G2P_DISABLED: pytest.skip("No Pynini found") output_path = os.path.join(generated_dir, "g2p_out.txt") @@ -33,11 +30,12 @@ def test_generate_pretrained( args, unknown = parser.parse_known_args(command) run_g2p(args, unknown) assert os.path.exists(output_path) - d = PronunciationDictionary(DictionaryModel(output_path), temp_dir, basic_dictionary_config) + d = PronunciationDictionary(output_path, temporary_directory=temp_dir) + assert len(d.words) > 0 -def test_train_g2p(sick_dict_path, sick_g2p_model_path, temp_dir, train_g2p_config): +def test_train_g2p(sick_dict_path, sick_g2p_model_path, temp_dir, train_g2p_config_path): if G2P_DISABLED: pytest.skip("No Pynini found") command = [ @@ -45,13 +43,13 @@ def test_train_g2p(sick_dict_path, sick_g2p_model_path, temp_dir, train_g2p_conf sick_dict_path, sick_g2p_model_path, "-t", - temp_dir, + os.path.join(temp_dir, "test_train_g2p"), "-q", "--clean", "--debug", "--validate", "--config_path", - train_g2p_config, + train_g2p_config_path, ] args, unknown = parser.parse_known_args(command) run_train_g2p(args, unknown) @@ -63,8 +61,7 @@ def test_generate_dict( sick_g2p_model_path, g2p_sick_output, temp_dir, - g2p_config, - basic_dictionary_config, + g2p_config_path, ): if G2P_DISABLED: pytest.skip("No Pynini found") @@ -79,14 +76,12 @@ def test_generate_dict( "--clean", "--debug", "--config_path", - g2p_config, + g2p_config_path, ] args, unknown = parser.parse_known_args(command) run_g2p(args, unknown) assert os.path.exists(g2p_sick_output) - d = PronunciationDictionary( - DictionaryModel(g2p_sick_output), temp_dir, basic_dictionary_config - ) + d = PronunciationDictionary(dictionary_path=g2p_sick_output, temporary_directory=temp_dir) assert len(d.words) > 0 @@ -95,8 +90,7 @@ def test_generate_dict_text_only( sick_g2p_model_path, g2p_sick_output, temp_dir, - g2p_config, - 
basic_dictionary_config, + g2p_config_path, ): if G2P_DISABLED: pytest.skip("No Pynini found") @@ -112,20 +106,16 @@ def test_generate_dict_text_only( "--clean", "--debug", "--config_path", - g2p_config, + g2p_config_path, ] args, unknown = parser.parse_known_args(command) run_g2p(args, unknown) assert os.path.exists(g2p_sick_output) - d = PronunciationDictionary( - DictionaryModel(g2p_sick_output), temp_dir, basic_dictionary_config - ) + d = PronunciationDictionary(dictionary_path=g2p_sick_output, temporary_directory=temp_dir) assert len(d.words) > 0 -def test_generate_orthography_dict( - basic_corpus_dir, orth_sick_output, temp_dir, basic_dictionary_config -): +def test_generate_orthography_dict(basic_corpus_dir, orth_sick_output, temp_dir): if G2P_DISABLED: pytest.skip("No Pynini found") command = [ @@ -143,7 +133,5 @@ def test_generate_orthography_dict( args, unknown = parser.parse_known_args(command) run_g2p(args, unknown) assert os.path.exists(orth_sick_output) - d = PronunciationDictionary( - DictionaryModel(orth_sick_output), temp_dir, basic_dictionary_config - ) + d = PronunciationDictionary(dictionary_path=orth_sick_output, temporary_directory=temp_dir) assert len(d.words) > 0 diff --git a/tests/test_commandline_history.py b/tests/test_commandline_history.py new file mode 100644 index 00000000..c0567dc4 --- /dev/null +++ b/tests/test_commandline_history.py @@ -0,0 +1,27 @@ +from montreal_forced_aligner.command_line.mfa import parser, print_history + + +def test_mfa_history( + multilingual_ipa_tg_corpus_dir, english_ipa_acoustic_model, english_us_ipa_dictionary, temp_dir +): + + command = ["history", "--depth", "60"] + args, unknown = parser.parse_known_args(command) + print_history(args) + + command = ["history"] + args, unknown = parser.parse_known_args(command) + print_history(args) + + +def test_mfa_history_verbose( + multilingual_ipa_tg_corpus_dir, english_ipa_acoustic_model, english_us_ipa_dictionary, temp_dir +): + + command = ["history", "-v", 
"--depth", "60"] + args, unknown = parser.parse_known_args(command) + print_history(args) + + command = ["history", "-v"] + args, unknown = parser.parse_known_args(command) + print_history(args) diff --git a/tests/test_commandline_lm.py b/tests/test_commandline_lm.py index 97a4208c..6915018e 100644 --- a/tests/test_commandline_lm.py +++ b/tests/test_commandline_lm.py @@ -7,9 +7,10 @@ from montreal_forced_aligner.command_line.train_lm import run_train_lm -def test_train_lm(basic_corpus_dir, temp_dir, generated_dir, basic_train_lm_config): +def test_train_lm(basic_corpus_dir, temp_dir, generated_dir, basic_train_lm_config_path): if sys.platform == "win32": pytest.skip("LM training not supported on Windows.") + temp_dir = os.path.join(temp_dir, "train_lm") command = [ "train_lm", basic_corpus_dir, @@ -17,7 +18,7 @@ def test_train_lm(basic_corpus_dir, temp_dir, generated_dir, basic_train_lm_conf "-t", temp_dir, "--config_path", - basic_train_lm_config, + basic_train_lm_config_path, "-q", "--clean", ] @@ -26,9 +27,10 @@ def test_train_lm(basic_corpus_dir, temp_dir, generated_dir, basic_train_lm_conf assert os.path.exists(args.output_model_path) -def test_train_lm_text(basic_split_dir, temp_dir, generated_dir, basic_train_lm_config): +def test_train_lm_text(basic_split_dir, temp_dir, generated_dir, basic_train_lm_config_path): if sys.platform == "win32": pytest.skip("LM training not supported on Windows.") + temp_dir = os.path.join(temp_dir, "train_lm_text") text_dir = basic_split_dir[1] command = [ "train_lm", @@ -37,7 +39,7 @@ def test_train_lm_text(basic_split_dir, temp_dir, generated_dir, basic_train_lm_ "-t", temp_dir, "--config_path", - basic_train_lm_config, + basic_train_lm_config_path, "-q", "--clean", ] @@ -46,9 +48,12 @@ def test_train_lm_text(basic_split_dir, temp_dir, generated_dir, basic_train_lm_ assert os.path.exists(args.output_model_path) -def test_train_lm_text_no_mp(basic_split_dir, temp_dir, generated_dir, basic_train_lm_config): +def 
test_train_lm_dictionary( + basic_split_dir, basic_dict_path, temp_dir, generated_dir, basic_train_lm_config_path +): if sys.platform == "win32": pytest.skip("LM training not supported on Windows.") + temp_dir = os.path.join(temp_dir, "train_lm_dictionary") text_dir = basic_split_dir[1] command = [ "train_lm", @@ -56,8 +61,52 @@ def test_train_lm_text_no_mp(basic_split_dir, temp_dir, generated_dir, basic_tra os.path.join(generated_dir, "test_basic_lm_split.zip"), "-t", temp_dir, + "--dictionary_path", + basic_dict_path, "--config_path", - basic_train_lm_config, + basic_train_lm_config_path, + "-q", + "--clean", + ] + args, unknown = parser.parse_known_args(command) + run_train_lm(args) + assert os.path.exists(args.output_model_path) + + +def test_train_lm_arpa( + transcription_language_model_arpa, temp_dir, generated_dir, basic_train_lm_config_path +): + if sys.platform == "win32": + pytest.skip("LM training not supported on Windows.") + temp_dir = os.path.join(temp_dir, "train_lm_arpa") + command = [ + "train_lm", + transcription_language_model_arpa, + os.path.join(generated_dir, "test_basic_lm_split.zip"), + "-t", + temp_dir, + "--config_path", + basic_train_lm_config_path, + "-q", + "--clean", + ] + args, unknown = parser.parse_known_args(command) + run_train_lm(args) + assert os.path.exists(args.output_model_path) + + +def test_train_lm_text_no_mp(basic_split_dir, temp_dir, generated_dir, basic_train_lm_config_path): + if sys.platform == "win32": + pytest.skip("LM training not supported on Windows.") + text_dir = basic_split_dir[1] + command = [ + "train_lm", + text_dir, + os.path.join(generated_dir, "test_basic_lm_split.zip"), + "-t", + temp_dir, + "--config_path", + basic_train_lm_config_path, "-q", "--clean", "-j", diff --git a/tests/test_commandline_model.py b/tests/test_commandline_model.py index 7b7bea98..5d94fca4 100644 --- a/tests/test_commandline_model.py +++ b/tests/test_commandline_model.py @@ -6,10 +6,10 @@ from 
montreal_forced_aligner.command_line.model import ( ModelTypeNotSupportedError, PretrainedModelNotFoundError, - get_pretrained_path, list_downloadable_models, run_model, ) +from montreal_forced_aligner.models import AcousticModel, DictionaryModel, G2PModel class DummyArgs(Namespace): @@ -44,7 +44,7 @@ def test_download(): run_model(args) - assert os.path.exists(get_pretrained_path("acoustic", args.name)) + assert os.path.exists(AcousticModel.get_pretrained_path(args.name)) args = DummyArgs() args.action = "download" @@ -53,7 +53,7 @@ def test_download(): run_model(args) - assert os.path.exists(get_pretrained_path("g2p", args.name)) + assert os.path.exists(G2PModel.get_pretrained_path(args.name)) args = DummyArgs() args.action = "download" @@ -62,7 +62,7 @@ def test_download(): run_model(args) - assert os.path.exists(get_pretrained_path("dictionary", args.name)) + assert os.path.exists(DictionaryModel.get_pretrained_path(args.name)) args = DummyArgs() args.action = "download" diff --git a/tests/test_commandline_train_dict.py b/tests/test_commandline_train_dict.py index cf712280..595660be 100644 --- a/tests/test_commandline_train_dict.py +++ b/tests/test_commandline_train_dict.py @@ -12,9 +12,9 @@ def test_train_dict( transcription_acoustic_model, transcription_language_model, temp_dir, - basic_align_config, + basic_align_config_path, ): - output_path = os.path.join(generated_dir, "trained_dict.txt") + output_path = os.path.join(generated_dir, "trained_dict") command = [ "train_dictionary", basic_corpus_dir, @@ -27,7 +27,7 @@ def test_train_dict( "--clean", "--debug", "--config_path", - basic_align_config, + basic_align_config_path, ] args, unknown = parser.parse_known_args(command) run_train_dictionary(args) diff --git a/tests/test_commandline_train_ivector.py b/tests/test_commandline_train_ivector.py index 498b4251..1c92bc0b 100644 --- a/tests/test_commandline_train_ivector.py +++ b/tests/test_commandline_train_ivector.py @@ -9,22 +9,18 @@ def test_basic_ivector( 
basic_corpus_dir, generated_dir, - english_dictionary, temp_dir, - train_ivector_config, - english_acoustic_model, + train_ivector_config_path, ivector_output_model_path, ): command = [ "train_ivector", basic_corpus_dir, - english_dictionary, - "english", ivector_output_model_path, "-t", temp_dir, "--config_path", - train_ivector_config, + train_ivector_config_path, "-q", "--clean", "--debug", diff --git a/tests/test_commandline_transcribe.py b/tests/test_commandline_transcribe.py index 3e1dbeae..d02d491c 100644 --- a/tests/test_commandline_transcribe.py +++ b/tests/test_commandline_transcribe.py @@ -15,7 +15,7 @@ def test_transcribe( transcription_acoustic_model, transcription_language_model, temp_dir, - transcribe_config, + transcribe_config_path, ): output_path = os.path.join(generated_dir, "transcribe_test") command = [ @@ -32,7 +32,7 @@ def test_transcribe( "--debug", "-v", "--config_path", - transcribe_config, + transcribe_config_path, ] args, unknown = parser.parse_known_args(command) run_transcribe_corpus(args) @@ -48,11 +48,13 @@ def test_transcribe_arpa( transcription_acoustic_model, transcription_language_model_arpa, temp_dir, - transcribe_config, + transcribe_config_path, ): if sys.platform == "win32": pytest.skip("No LM generation on Windows") + temp_dir = os.path.join(temp_dir, "arpa_test_temp") output_path = os.path.join(generated_dir, "transcribe_test_arpa") + print(transcription_language_model_arpa) command = [ "transcribe", basic_corpus_dir, @@ -67,11 +69,10 @@ def test_transcribe_arpa( "--debug", "-v", "--config_path", - transcribe_config, + transcribe_config_path, ] args, unknown = parser.parse_known_args(command) run_transcribe_corpus(args) - assert os.path.exists(os.path.join(output_path, "michael", "acoustic_corpus.lab")) @@ -82,7 +83,7 @@ def test_transcribe_speaker_dictionaries( generated_dir, transcription_language_model, temp_dir, - transcribe_config, + transcribe_config_path, ): output_path = os.path.join(generated_dir, 
"transcribe_test") command = [ @@ -98,7 +99,7 @@ def test_transcribe_speaker_dictionaries( "--clean", "--debug", "--config_path", - transcribe_config, + transcribe_config_path, ] args, unknown = parser.parse_known_args(command) run_transcribe_corpus(args) @@ -111,7 +112,7 @@ def test_transcribe_speaker_dictionaries_evaluate( generated_dir, transcription_language_model, temp_dir, - transcribe_config, + transcribe_config_path, ): output_path = os.path.join(generated_dir, "transcribe_test") command = [ @@ -127,7 +128,7 @@ def test_transcribe_speaker_dictionaries_evaluate( "--clean", "--debug", "--config_path", - transcribe_config, + transcribe_config_path, "--evaluate", ] args, unknown = parser.parse_known_args(command) diff --git a/tests/test_commandline_validate.py b/tests/test_commandline_validate.py index 2906d86b..dc1618ed 100644 --- a/tests/test_commandline_validate.py +++ b/tests/test_commandline_validate.py @@ -21,3 +21,27 @@ def test_validate_corpus( ] args, unknown = parser.parse_known_args(command) run_validate_corpus(args) + + +def test_validate_training_corpus( + multilingual_ipa_tg_corpus_dir, + english_ipa_acoustic_model, + english_dictionary, + temp_dir, + mono_train_config_path, +): + + command = [ + "validate", + multilingual_ipa_tg_corpus_dir, + english_dictionary, + "-t", + temp_dir, + "-q", + "--clean", + "--debug", + "--config_path", + mono_train_config_path, + ] + args, unknown = parser.parse_known_args(command) + run_validate_corpus(args) diff --git a/tests/test_config.py b/tests/test_config.py index 93bd3d4c..8b54611b 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -2,23 +2,27 @@ import pytest -from montreal_forced_aligner.config import ( - FeatureConfig, - align_yaml_to_config, - train_yaml_to_config, -) -from montreal_forced_aligner.exceptions import ConfigError -from montreal_forced_aligner.trainers import ( +from montreal_forced_aligner.acoustic_modeling import ( LdaTrainer, MonophoneTrainer, SatTrainer, + TrainableAligner, 
TriphoneTrainer, ) +from montreal_forced_aligner.alignment import PretrainedAligner +from montreal_forced_aligner.dictionary.mixins import DEFAULT_STRIP_DIACRITICS +from montreal_forced_aligner.exceptions import ConfigError +from montreal_forced_aligner.ivector.trainer import TrainableIvectorExtractor -def test_monophone_config(): - config = MonophoneTrainer(FeatureConfig()) +def test_monophone_config(sick_corpus, sick_dict, temp_dir): + am_trainer = TrainableAligner( + corpus_directory=sick_corpus, dictionary_path=sick_dict, temporary_directory=temp_dir + ) + config = MonophoneTrainer(identifier="mono", worker=am_trainer) + config.compute_calculated_properties() assert config.realignment_iterations == [ + 0, 1, 2, 3, @@ -41,79 +45,156 @@ def test_monophone_config(): 35, 38, ] + am_trainer.cleanup() -def test_triphone_config(): - config = TriphoneTrainer(FeatureConfig()) +def test_triphone_config(sick_corpus, sick_dict, temp_dir): + am_trainer = TrainableAligner( + corpus_directory=sick_corpus, dictionary_path=sick_dict, temporary_directory=temp_dir + ) + config = TriphoneTrainer(identifier="tri", worker=am_trainer) + config.compute_calculated_properties() assert config.realignment_iterations == [10, 20, 30] + am_trainer.cleanup() -def test_lda_mllt_config(): - config = LdaTrainer(FeatureConfig()) - assert config.mllt_iterations == [2, 4, 6, 16] +def test_lda_mllt_config(sick_corpus, sick_dict, temp_dir): + am_trainer = TrainableAligner( + corpus_directory=sick_corpus, dictionary_path=sick_dict, temporary_directory=temp_dir + ) - -def test_load_align(config_directory, mono_align_config_path): - _ = align_yaml_to_config(mono_align_config_path) + assert am_trainer.beam == 10 + assert am_trainer.retry_beam == 40 + assert am_trainer.align_options["beam"] == 10 + assert am_trainer.align_options["retry_beam"] == 40 + config = LdaTrainer(identifier="lda", worker=am_trainer) + config.compute_calculated_properties() + assert config.mllt_iterations == [2, 4, 6, 16] + 
am_trainer.cleanup() + + +def test_load_align( + config_directory, + sick_corpus, + sick_dict, + temp_dir, + english_acoustic_model, + mono_align_config_path, +): + params = PretrainedAligner.parse_parameters(mono_align_config_path) + aligner = PretrainedAligner( + acoustic_model_path=english_acoustic_model, + corpus_directory=sick_corpus, + dictionary_path=sick_dict, + temporary_directory=temp_dir, + **params + ) + + assert params["beam"] == 100 + assert params["retry_beam"] == 400 + assert aligner.beam == 100 + assert aligner.retry_beam == 400 + assert aligner.align_options["beam"] == 100 + assert aligner.align_options["retry_beam"] == 400 + aligner.cleanup() path = os.path.join(config_directory, "bad_align_config.yaml") - with pytest.raises(ConfigError): - _ = align_yaml_to_config(path) - - -def test_load_basic_train(config_directory, basic_train_config): - training_config, align_config, dictioanry_config = train_yaml_to_config(basic_train_config) - assert align_config.beam == 100 - assert align_config.retry_beam == 400 - assert align_config.align_options["beam"] == 100 - assert align_config.align_options["retry_beam"] == 400 - - for trainer in training_config.training_configs: + params = PretrainedAligner.parse_parameters(path) + print(params) + aligner = PretrainedAligner( + acoustic_model_path=english_acoustic_model, + corpus_directory=sick_corpus, + dictionary_path=sick_dict, + temporary_directory=temp_dir, + **params + ) + assert aligner.beam == 10 + assert aligner.retry_beam == 40 + aligner.cleanup() + + +def test_load_basic_train(sick_corpus, sick_dict, temp_dir, basic_train_config_path): + params = TrainableAligner.parse_parameters(basic_train_config_path) + am_trainer = TrainableAligner( + corpus_directory=sick_corpus, + dictionary_path=sick_dict, + temporary_directory=temp_dir, + **params + ) + + assert am_trainer.beam == 100 + assert am_trainer.retry_beam == 400 + assert am_trainer.align_options["beam"] == 100 + assert 
am_trainer.align_options["retry_beam"] == 400 + + for trainer in am_trainer.training_configs.values(): assert trainer.beam == 100 assert trainer.retry_beam == 400 assert trainer.align_options["beam"] == 100 assert trainer.align_options["retry_beam"] == 400 - - -def test_load_mono_train(config_directory, mono_train_config_path): - train, align, dictioanry_config = train_yaml_to_config(mono_train_config_path) - for t in train.training_configs: + am_trainer.cleanup() + + +def test_load_mono_train(sick_corpus, sick_dict, temp_dir, mono_train_config_path): + params = TrainableAligner.parse_parameters(mono_train_config_path) + am_trainer = TrainableAligner( + corpus_directory=sick_corpus, + dictionary_path=sick_dict, + temporary_directory=temp_dir, + **params + ) + for t in am_trainer.training_configs.values(): assert not t.use_mp - assert not t.feature_config.use_mp - assert t.feature_config.use_energy - assert not align.use_mp - assert not align.feature_config.use_mp - assert align.feature_config.use_energy + assert t.use_energy + assert not am_trainer.use_mp + assert am_trainer.use_energy + am_trainer.cleanup() + +def test_load_ivector_train(sick_corpus, sick_dict, temp_dir, train_ivector_config_path): + params = TrainableIvectorExtractor.parse_parameters(train_ivector_config_path) + trainer = TrainableIvectorExtractor( + corpus_directory=sick_corpus, temporary_directory=temp_dir, **params + ) -def test_load_ivector_train(config_directory, train_ivector_config): - train, align, dictioanry_config = train_yaml_to_config(train_ivector_config) - for t in train.training_configs: + for t in trainer.training_configs.values(): assert not t.use_mp - assert not t.feature_config.use_mp - assert t.feature_config.use_energy - assert not align.use_mp - assert not align.feature_config.use_mp + assert t.use_energy + assert not trainer.use_mp + trainer.cleanup() -def test_load(config_directory): +def test_load(sick_corpus, sick_dict, temp_dir, config_directory): path = 
os.path.join(config_directory, "basic_train_config.yaml") - train, align, dictionary_config = train_yaml_to_config(path) - assert len(train.training_configs) == 4 - assert isinstance(train.training_configs[0], MonophoneTrainer) - assert isinstance(train.training_configs[1], TriphoneTrainer) - assert isinstance(train.training_configs[-1], SatTrainer) + params = TrainableAligner.parse_parameters(path) + am_trainer = TrainableAligner( + corpus_directory=sick_corpus, + dictionary_path=sick_dict, + temporary_directory=temp_dir, + **params + ) + assert len(am_trainer.training_configs) == 4 + assert isinstance(am_trainer.training_configs["monophone"], MonophoneTrainer) + assert isinstance(am_trainer.training_configs["triphone"], TriphoneTrainer) + assert isinstance(am_trainer.training_configs[am_trainer.final_identifier], SatTrainer) path = os.path.join(config_directory, "out_of_order_config.yaml") with pytest.raises(ConfigError): - train, align, dictionary_config = train_yaml_to_config(path) + params = TrainableAligner.parse_parameters(path) + am_trainer.cleanup() -def test_multilingual_ipa(config_directory): - from montreal_forced_aligner.config.dictionary_config import DEFAULT_STRIP_DIACRITICS +def test_multilingual_ipa(sick_corpus, sick_dict, temp_dir, config_directory): path = os.path.join(config_directory, "basic_ipa_config.yaml") - train, align, dictionary_config = train_yaml_to_config(path) - assert dictionary_config.multilingual_ipa - assert set(dictionary_config.strip_diacritics) == set(DEFAULT_STRIP_DIACRITICS) - assert dictionary_config.digraphs == ["[dt][szʒʃʐʑʂɕç]", "[a][job_name][u]"] + params = TrainableAligner.parse_parameters(path) + am_trainer = TrainableAligner( + corpus_directory=sick_corpus, + dictionary_path=sick_dict, + temporary_directory=temp_dir, + **params + ) + assert am_trainer.multilingual_ipa + assert set(am_trainer.strip_diacritics) == set(DEFAULT_STRIP_DIACRITICS) + assert am_trainer.digraphs == ["[dt][szʒʃʐʑʂɕç]", "[a][job_name][u]"] + 
am_trainer.cleanup() diff --git a/tests/test_corpus.py b/tests/test_corpus.py index bf8e6efd..69f7219d 100644 --- a/tests/test_corpus.py +++ b/tests/test_corpus.py @@ -3,11 +3,10 @@ import pytest -from montreal_forced_aligner.config.train_config import train_yaml_to_config -from montreal_forced_aligner.corpus import Corpus +from montreal_forced_aligner.corpus.acoustic_corpus import AcousticCorpus from montreal_forced_aligner.corpus.classes import File, Speaker, Utterance from montreal_forced_aligner.corpus.helper import get_wav_info -from montreal_forced_aligner.dictionary import MultispeakerDictionary +from montreal_forced_aligner.corpus.text_corpus import TextCorpus from montreal_forced_aligner.exceptions import SoxError @@ -19,262 +18,313 @@ def test_mp3(mp3_test_path): pytest.skip() -def test_add(basic_corpus_dir, generated_dir): - output_directory = os.path.join(generated_dir, "basic") +def test_speaker_word_set( + multilingual_ipa_tg_corpus_dir, multispeaker_dictionary_config_path, temp_dir +): + corpus = AcousticCorpus( + corpus_directory=multilingual_ipa_tg_corpus_dir, + dictionary_path=multispeaker_dictionary_config_path, + temporary_directory=temp_dir, + ) + corpus.load_corpus() + speaker_one = corpus.speakers["speaker_one"] + assert "chad" in speaker_one.word_set() + assert speaker_one.dictionary_data.lookup("chad-like") == ["chad", "like"] + assert speaker_one.dictionary_data.oov_int not in speaker_one.dictionary_data.to_int( + "chad-like" + ) + + +def test_add(basic_corpus_dir, sick_dict_path, generated_dir): + output_directory = os.path.join(generated_dir, "corpus_tests") if os.path.exists(output_directory): shutil.rmtree(output_directory, ignore_errors=True) - c = Corpus(basic_corpus_dir, output_directory, use_mp=True) + corpus = AcousticCorpus( + corpus_directory=basic_corpus_dir, + dictionary_path=sick_dict_path, + use_mp=True, + temporary_directory=output_directory, + ) + corpus._load_corpus() new_speaker = Speaker("new_speaker") new_file = 
File("new_file.wav", "new_file.txt") new_utterance = Utterance(new_speaker, new_file, text="blah blah") utterance_id = new_utterance.name - assert utterance_id not in c.utterances - c.add_utterance(new_utterance) + assert utterance_id not in corpus.utterances + corpus.add_utterance(new_utterance) - assert utterance_id in c.utterances - assert utterance_id in c.speakers["new_speaker"].utterances - assert utterance_id in c.files["new_file"].utterances - assert c.utterances[utterance_id].text == "blah blah" + assert utterance_id in corpus.utterances + assert utterance_id in corpus.speakers["new_speaker"].utterances + assert utterance_id in corpus.files["new_file"].utterances + assert corpus.utterances[utterance_id].text == "blah blah" - c.delete_utterance(utterance_id) - assert utterance_id not in c.utterances - assert "new_speaker" in c.speakers - assert "new_file" in c.files + corpus.delete_utterance(utterance_id) + assert utterance_id not in corpus.utterances + assert "new_speaker" in corpus.speakers + assert "new_file" in corpus.files -def test_basic(basic_dict_path, basic_corpus_dir, generated_dir, default_feature_config): - output_directory = os.path.join(generated_dir, "basic") +def test_basic(basic_dict_path, basic_corpus_dir, generated_dir): + output_directory = os.path.join(generated_dir, "corpus_tests") if os.path.exists(output_directory): shutil.rmtree(output_directory, ignore_errors=True) - dictionary = MultispeakerDictionary(basic_dict_path, output_directory) - dictionary.write() - c = Corpus(basic_corpus_dir, output_directory, use_mp=True) - c.initialize_corpus(dictionary, default_feature_config) - for speaker in c.speakers.values(): + corpus = AcousticCorpus( + corpus_directory=basic_corpus_dir, + dictionary_path=basic_dict_path, + use_mp=False, + temporary_directory=output_directory, + ) + corpus.load_corpus() + for speaker in corpus.speakers.values(): data = speaker.dictionary.data() - assert speaker.dictionary.config.silence_phones == 
data.dictionary_config.silence_phones - assert ( - speaker.dictionary.config.multilingual_ipa == data.dictionary_config.multilingual_ipa - ) + assert speaker.dictionary.silence_phones == data.silence_phones + assert speaker.dictionary.multilingual_ipa == data.multilingual_ipa assert speaker.dictionary.words_mapping == data.words_mapping - assert speaker.dictionary.config.punctuation == data.dictionary_config.punctuation - assert speaker.dictionary.config.clitic_markers == data.dictionary_config.clitic_markers + assert speaker.dictionary.punctuation == data.punctuation + assert speaker.dictionary.clitic_markers == data.clitic_markers assert speaker.dictionary.oov_int == data.oov_int assert speaker.dictionary.words == data.words - assert c.get_feat_dim() == 39 + assert corpus.get_feat_dim() == 39 -def test_basic_txt(basic_corpus_txt_dir, basic_dict_path, generated_dir, default_feature_config): - output_directory = os.path.join(generated_dir, "basic") +def test_basic_txt(basic_corpus_txt_dir, basic_dict_path, generated_dir): + output_directory = os.path.join(generated_dir, "corpus_tests") if os.path.exists(output_directory): shutil.rmtree(output_directory, ignore_errors=True) - dictionary = MultispeakerDictionary(basic_dict_path, os.path.join(generated_dir, "basic")) - dictionary.write() - c = Corpus(basic_corpus_txt_dir, output_directory, use_mp=False) - print(c.no_transcription_files) - assert len(c.no_transcription_files) == 0 - c.initialize_corpus(dictionary, default_feature_config) - assert c.get_feat_dim() == 39 + corpus = AcousticCorpus( + corpus_directory=basic_corpus_txt_dir, + dictionary_path=basic_dict_path, + use_mp=False, + temporary_directory=output_directory, + ) + corpus.load_corpus() + print(corpus.no_transcription_files) + assert len(corpus.no_transcription_files) == 0 + assert corpus.get_feat_dim() == 39 -def test_alignable_from_temp( - basic_corpus_txt_dir, basic_dict_path, generated_dir, default_feature_config -): - dictionary = 
MultispeakerDictionary(basic_dict_path, os.path.join(generated_dir, "basic")) - dictionary.write() - output_directory = os.path.join(generated_dir, "basic") + +def test_acoustic_from_temp(basic_corpus_txt_dir, basic_dict_path, generated_dir): + output_directory = os.path.join(generated_dir, "corpus_tests") if os.path.exists(output_directory): shutil.rmtree(output_directory, ignore_errors=True) - c = Corpus(basic_corpus_txt_dir, output_directory, use_mp=False) - assert len(c.no_transcription_files) == 0 - c.initialize_corpus(dictionary, default_feature_config) - assert c.get_feat_dim() == 39 + corpus = AcousticCorpus( + corpus_directory=basic_corpus_txt_dir, + dictionary_path=basic_dict_path, + use_mp=False, + temporary_directory=output_directory, + ) + corpus.load_corpus() + assert len(corpus.no_transcription_files) == 0 + assert corpus.get_feat_dim() == 39 - c = Corpus(basic_corpus_txt_dir, output_directory, use_mp=False) - assert len(c.no_transcription_files) == 0 - c.initialize_corpus(dictionary, default_feature_config) - assert c.get_feat_dim() == 39 + new_corpus = AcousticCorpus( + corpus_directory=basic_corpus_txt_dir, + dictionary_path=basic_dict_path, + use_mp=False, + temporary_directory=output_directory, + ) + new_corpus.load_corpus() + assert len(new_corpus.no_transcription_files) == 0 + assert new_corpus.get_feat_dim() == 39 -def test_transcribe_from_temp( - basic_corpus_txt_dir, basic_dict_path, generated_dir, default_feature_config -): - dictionary = MultispeakerDictionary(basic_dict_path, os.path.join(generated_dir, "basic")) - dictionary.write() - output_directory = os.path.join(generated_dir, "basic") +def test_text_corpus_from_temp(basic_corpus_txt_dir, basic_dict_path, generated_dir): + output_directory = os.path.join(generated_dir, "corpus_tests") if os.path.exists(output_directory): shutil.rmtree(output_directory, ignore_errors=True) - c = Corpus(basic_corpus_txt_dir, output_directory, use_mp=False) - c.initialize_corpus(dictionary, 
default_feature_config) - assert c.get_feat_dim() == 39 - - c = Corpus(basic_corpus_txt_dir, output_directory, use_mp=False) - c.initialize_corpus(dictionary, default_feature_config) - assert c.get_feat_dim() == 39 + corpus = TextCorpus( + corpus_directory=basic_corpus_txt_dir, + dictionary_path=basic_dict_path, + use_mp=False, + temporary_directory=output_directory, + ) + corpus.load_corpus() + assert len(corpus.utterances) > 0 def test_extra(sick_dict, extra_corpus_dir, generated_dir): - output_directory = os.path.join(generated_dir, "extra") + output_directory = os.path.join(generated_dir, "corpus_tests") if os.path.exists(output_directory): shutil.rmtree(output_directory, ignore_errors=True) - corpus = Corpus(extra_corpus_dir, output_directory, num_jobs=2, use_mp=False) - corpus.initialize_corpus(sick_dict) + corpus = AcousticCorpus( + corpus_directory=extra_corpus_dir, + dictionary_path=sick_dict, + use_mp=False, + num_jobs=2, + temporary_directory=output_directory, + ) + corpus.load_corpus() + assert len(corpus.no_transcription_files) == 0 + assert corpus.get_feat_dim() == 39 -def test_stereo(basic_dict_path, stereo_corpus_dir, temp_dir, default_feature_config): - temp = os.path.join(temp_dir, "stereo") - if os.path.exists(temp): - shutil.rmtree(temp, ignore_errors=True) - dictionary = MultispeakerDictionary(basic_dict_path, os.path.join(temp, "basic")) - dictionary.write() - d = Corpus(stereo_corpus_dir, temp, use_mp=False) - d.initialize_corpus(dictionary, default_feature_config) - assert d.get_feat_dim() == 39 +def test_stereo(basic_dict_path, stereo_corpus_dir, generated_dir): -def test_stereo_short_tg( - basic_dict_path, stereo_corpus_short_tg_dir, temp_dir, default_feature_config -): - temp = os.path.join(temp_dir, "stereo_tg") - if os.path.exists(temp): - shutil.rmtree(temp, ignore_errors=True) - dictionary = MultispeakerDictionary(basic_dict_path, os.path.join(temp, "basic")) - dictionary.write() - d = Corpus(stereo_corpus_short_tg_dir, temp, 
use_mp=False) - d.initialize_corpus(dictionary, default_feature_config) - assert d.get_feat_dim() == 39 - - -def test_flac(basic_dict_path, flac_corpus_dir, temp_dir, default_feature_config): - temp = os.path.join(temp_dir, "flac") - if os.path.exists(temp): - shutil.rmtree(temp, ignore_errors=True) - dictionary = MultispeakerDictionary(basic_dict_path, os.path.join(temp, "basic")) - dictionary.write() - d = Corpus(flac_corpus_dir, temp, use_mp=False) - d.initialize_corpus(dictionary, default_feature_config) - assert d.get_feat_dim() == 39 - - -def test_audio_directory(basic_dict_path, basic_split_dir, temp_dir, default_feature_config): - temp = os.path.join(temp_dir, "audio_dir_test") + output_directory = os.path.join(generated_dir, "corpus_tests") + if os.path.exists(output_directory): + shutil.rmtree(output_directory, ignore_errors=True) + + corpus = AcousticCorpus( + corpus_directory=stereo_corpus_dir, + dictionary_path=basic_dict_path, + use_mp=False, + num_jobs=1, + temporary_directory=output_directory, + ) + corpus.load_corpus() + assert len(corpus.no_transcription_files) == 0 + assert corpus.get_feat_dim() == 39 + assert corpus.files["michaelandsickmichael"].num_channels == 2 + + +def test_stereo_short_tg(basic_dict_path, stereo_corpus_short_tg_dir, generated_dir): + output_directory = os.path.join(generated_dir, "corpus_tests") + if os.path.exists(output_directory): + shutil.rmtree(output_directory, ignore_errors=True) + + corpus = AcousticCorpus( + corpus_directory=stereo_corpus_short_tg_dir, + dictionary_path=basic_dict_path, + use_mp=False, + temporary_directory=output_directory, + ) + corpus.load_corpus() + assert len(corpus.no_transcription_files) == 0 + assert corpus.get_feat_dim() == 39 + assert corpus.files["michaelandsickmichael"].num_channels == 2 + + +def test_flac(basic_dict_path, flac_corpus_dir, generated_dir): + output_directory = os.path.join(generated_dir, "corpus_tests") + if os.path.exists(output_directory): + 
shutil.rmtree(output_directory, ignore_errors=True) + + corpus = AcousticCorpus( + corpus_directory=flac_corpus_dir, + dictionary_path=basic_dict_path, + use_mp=False, + temporary_directory=output_directory, + ) + corpus.load_corpus() + assert len(corpus.no_transcription_files) == 0 + assert corpus.get_feat_dim() == 39 + + +def test_audio_directory(basic_dict_path, basic_split_dir, generated_dir): audio_dir, text_dir = basic_split_dir - if os.path.exists(temp): - shutil.rmtree(temp, ignore_errors=True) - dictionary = MultispeakerDictionary(basic_dict_path, os.path.join(temp, "basic")) - dictionary.write() - d = Corpus(text_dir, temp, use_mp=False, audio_directory=audio_dir) - assert len(d.no_transcription_files) == 0 - assert len(d.files) > 0 - d.initialize_corpus(dictionary, default_feature_config) - assert d.get_feat_dim() == 39 - - if os.path.exists(temp): - shutil.rmtree(temp, ignore_errors=True) - dictionary = MultispeakerDictionary(basic_dict_path, os.path.join(temp, "basic")) - dictionary.write() - d = Corpus(text_dir, temp, use_mp=True, audio_directory=audio_dir) - assert len(d.no_transcription_files) == 0 - assert len(d.files) > 0 - d.initialize_corpus(dictionary, default_feature_config) - assert d.get_feat_dim() == 39 - - -def test_flac_mp(basic_dict_path, flac_corpus_dir, temp_dir, default_feature_config): - temp = os.path.join(temp_dir, "flac") - if os.path.exists(temp): - shutil.rmtree(temp, ignore_errors=True) - dictionary = MultispeakerDictionary(basic_dict_path, os.path.join(temp, "basic")) - dictionary.write() - d = Corpus(flac_corpus_dir, temp, use_mp=True) - d.initialize_corpus(dictionary, default_feature_config) - assert d.get_feat_dim() == 39 - - -def test_flac_tg(basic_dict_path, flac_tg_corpus_dir, temp_dir, default_feature_config): - temp = os.path.join(temp_dir, "flac") - if os.path.exists(temp): - shutil.rmtree(temp, ignore_errors=True) - dictionary = MultispeakerDictionary(basic_dict_path, os.path.join(temp, "basic")) - dictionary.write() 
- d = Corpus(flac_tg_corpus_dir, temp, use_mp=False) - d.initialize_corpus(dictionary, default_feature_config) - assert d.get_feat_dim() == 39 - - -def test_flac_tg_mp(basic_dict_path, flac_tg_corpus_dir, temp_dir, default_feature_config): - temp = os.path.join(temp_dir, "flac") - if os.path.exists(temp): - shutil.rmtree(temp, ignore_errors=True) - dictionary = MultispeakerDictionary(basic_dict_path, os.path.join(temp, "basic")) - dictionary.write() - d = Corpus(flac_tg_corpus_dir, temp, use_mp=True) - d.initialize_corpus(dictionary, default_feature_config) - assert d.get_feat_dim() == 39 - - -def test_flac_tg_transcribe(basic_dict_path, flac_tg_corpus_dir, temp_dir, default_feature_config): - temp = os.path.join(temp_dir, "flac_tg") - if os.path.exists(temp): - shutil.rmtree(temp, ignore_errors=True) - dictionary = MultispeakerDictionary(basic_dict_path, os.path.join(temp, "basic")) - dictionary.write() - d = Corpus(flac_tg_corpus_dir, temp, use_mp=False) - d.initialize_corpus(dictionary, default_feature_config) - assert d.get_feat_dim() == 39 - - if os.path.exists(temp): - shutil.rmtree(temp, ignore_errors=True) - dictionary = MultispeakerDictionary(basic_dict_path, os.path.join(temp, "basic")) - dictionary.write() - d = Corpus(flac_tg_corpus_dir, temp, use_mp=True) - d.initialize_corpus(dictionary, default_feature_config) - assert d.get_feat_dim() == 39 - - -def test_flac_transcribe( - basic_dict_path, flac_transcribe_corpus_dir, temp_dir, default_feature_config -): - temp = os.path.join(temp_dir, "flac_transcribe") - if os.path.exists(temp): - shutil.rmtree(temp, ignore_errors=True) - dictionary = MultispeakerDictionary(basic_dict_path, os.path.join(temp, "basic")) - dictionary.write() - d = Corpus(flac_transcribe_corpus_dir, temp, use_mp=True) - d.initialize_corpus(dictionary, default_feature_config) - assert d.get_feat_dim() == 39 - - if os.path.exists(temp): - shutil.rmtree(temp, ignore_errors=True) - dictionary = MultispeakerDictionary(basic_dict_path, 
os.path.join(temp, "basic")) - dictionary.write() - - d = Corpus(flac_transcribe_corpus_dir, temp, use_mp=False) - d.initialize_corpus(dictionary, default_feature_config) - assert d.get_feat_dim() == 39 - - -def test_24bit_wav(transcribe_corpus_24bit_dir, temp_dir, default_feature_config): - temp = os.path.join(temp_dir, "24bit") - if os.path.exists(temp): - shutil.rmtree(temp, ignore_errors=True) - - c = Corpus(transcribe_corpus_24bit_dir, temp, use_mp=False) - c.initialize_corpus(feature_config=default_feature_config) - assert c.get_feat_dim() == 39 - assert len(c.files) == 2 - - -def test_short_segments( - basic_dict_path, shortsegments_corpus_dir, temp_dir, default_feature_config -): - temp = os.path.join(temp_dir, "short_segments") - if os.path.exists(temp): - shutil.rmtree(temp, ignore_errors=True) - dictionary = MultispeakerDictionary(basic_dict_path, temp) - dictionary.write() - corpus = Corpus(shortsegments_corpus_dir, temp, use_mp=False) - corpus.initialize_corpus(dictionary, default_feature_config) + output_directory = os.path.join(generated_dir, "corpus_tests") + if os.path.exists(output_directory): + shutil.rmtree(output_directory, ignore_errors=True) + + corpus = AcousticCorpus( + corpus_directory=text_dir, + dictionary_path=basic_dict_path, + use_mp=False, + audio_directory=audio_dir, + temporary_directory=output_directory, + ) + corpus.load_corpus() + assert len(corpus.no_transcription_files) == 0 + assert corpus.get_feat_dim() == 39 + assert len(corpus.files) > 0 + + if os.path.exists(output_directory): + shutil.rmtree(output_directory, ignore_errors=True) + corpus = AcousticCorpus( + corpus_directory=text_dir, + dictionary_path=basic_dict_path, + use_mp=True, + audio_directory=audio_dir, + temporary_directory=output_directory, + ) + corpus.load_corpus() + assert len(corpus.no_transcription_files) == 0 + assert corpus.get_feat_dim() == 39 + assert len(corpus.files) > 0 + + +def test_flac_mp(basic_dict_path, flac_corpus_dir, generated_dir): + 
output_directory = os.path.join(generated_dir, "corpus_tests") + if os.path.exists(output_directory): + shutil.rmtree(output_directory, ignore_errors=True) + + corpus = AcousticCorpus( + corpus_directory=flac_corpus_dir, + dictionary_path=basic_dict_path, + use_mp=True, + temporary_directory=output_directory, + ) + corpus.load_corpus() + assert len(corpus.no_transcription_files) == 0 + assert corpus.get_feat_dim() == 39 + assert len(corpus.files) > 0 + + +def test_flac_tg(basic_dict_path, flac_tg_corpus_dir, generated_dir): + output_directory = os.path.join(generated_dir, "corpus_tests") + if os.path.exists(output_directory): + shutil.rmtree(output_directory, ignore_errors=True) + + corpus = AcousticCorpus( + corpus_directory=flac_tg_corpus_dir, + dictionary_path=basic_dict_path, + use_mp=False, + temporary_directory=output_directory, + ) + corpus.load_corpus() + assert len(corpus.no_transcription_files) == 0 + assert corpus.get_feat_dim() == 39 + assert len(corpus.files) > 0 + + +def test_flac_tg_mp(basic_dict_path, flac_tg_corpus_dir, generated_dir): + output_directory = os.path.join(generated_dir, "corpus_tests") + if os.path.exists(output_directory): + shutil.rmtree(output_directory, ignore_errors=True) + + corpus = AcousticCorpus( + corpus_directory=flac_tg_corpus_dir, + dictionary_path=basic_dict_path, + use_mp=True, + temporary_directory=output_directory, + ) + corpus.load_corpus() + assert len(corpus.no_transcription_files) == 0 + assert corpus.get_feat_dim() == 39 + assert len(corpus.files) > 0 + + +def test_24bit_wav(transcribe_corpus_24bit_dir, basic_dict_path, generated_dir): + output_directory = os.path.join(generated_dir, "corpus_tests") + if os.path.exists(output_directory): + shutil.rmtree(output_directory, ignore_errors=True) + + corpus = AcousticCorpus( + corpus_directory=transcribe_corpus_24bit_dir, + dictionary_path=basic_dict_path, + use_mp=False, + temporary_directory=output_directory, + ) + corpus.load_corpus() + assert 
len(corpus.no_transcription_files) == 2 + assert corpus.get_feat_dim() == 39 + assert len(corpus.files) > 0 + + +def test_short_segments(basic_dict_path, shortsegments_corpus_dir, generated_dir): + output_directory = os.path.join(generated_dir, "corpus_tests") + if os.path.exists(output_directory): + shutil.rmtree(output_directory, ignore_errors=True) + + corpus = AcousticCorpus( + corpus_directory=shortsegments_corpus_dir, + dictionary_path=basic_dict_path, + use_mp=False, + temporary_directory=output_directory, + ) + corpus.load_corpus() assert len(corpus.utterances) == 3 assert len([x for x in corpus.utterances.values() if not x.ignored]) == 1 assert len([x for x in corpus.utterances.values() if x.features is not None]) == 1 @@ -282,130 +332,156 @@ def test_short_segments( assert len([x for x in corpus.utterances.values() if x.features is None]) == 2 -def test_speaker_groupings( - multilingual_ipa_corpus_dir, temp_dir, english_us_ipa_dictionary, default_feature_config -): - output_directory = os.path.join(temp_dir, "speaker_groupings") +def test_speaker_groupings(multilingual_ipa_corpus_dir, generated_dir, english_us_ipa_dictionary): + output_directory = os.path.join(generated_dir, "corpus_tests") if os.path.exists(output_directory): shutil.rmtree(output_directory, ignore_errors=True) - dictionary = MultispeakerDictionary(english_us_ipa_dictionary, output_directory) - dictionary.write() - c = Corpus(multilingual_ipa_corpus_dir, output_directory, use_mp=True) - c.initialize_corpus(dictionary, default_feature_config) + corpus = AcousticCorpus( + corpus_directory=multilingual_ipa_corpus_dir, + dictionary_path=english_us_ipa_dictionary, + use_mp=True, + temporary_directory=output_directory, + ) + corpus.load_corpus() speakers = os.listdir(multilingual_ipa_corpus_dir) for s in speakers: - assert any(s in x.speakers for x in c.jobs) + assert any(s in x.speakers for x in corpus.jobs) for _, _, files in os.walk(multilingual_ipa_corpus_dir): for f in files: name, ext = 
os.path.splitext(f) - assert name in c.files + assert name in corpus.files shutil.rmtree(output_directory, ignore_errors=True) - dictionary.write() - c = Corpus(multilingual_ipa_corpus_dir, output_directory, num_jobs=1, use_mp=True) - - c.initialize_corpus(dictionary, default_feature_config) + new_corpus = AcousticCorpus( + corpus_directory=multilingual_ipa_corpus_dir, + dictionary_path=english_us_ipa_dictionary, + num_jobs=1, + use_mp=True, + temporary_directory=output_directory, + ) + new_corpus.load_corpus() for s in speakers: - assert any(s in x.speakers for x in c.jobs) + assert any(s in x.speakers for x in new_corpus.jobs) for _, _, files in os.walk(multilingual_ipa_corpus_dir): for f in files: name, ext = os.path.splitext(f) - assert name in c.files + assert name in new_corpus.files -def test_subset( - multilingual_ipa_corpus_dir, temp_dir, english_us_ipa_dictionary, default_feature_config -): - output_directory = os.path.join(temp_dir, "large_subset") - shutil.rmtree(output_directory, ignore_errors=True) - dictionary = MultispeakerDictionary(english_us_ipa_dictionary, output_directory) - dictionary.write() - c = Corpus(multilingual_ipa_corpus_dir, output_directory, use_mp=False) - c.initialize_corpus(dictionary, default_feature_config) - sd = c.split_directory +def test_subset(multilingual_ipa_corpus_dir, generated_dir, english_us_ipa_dictionary): + output_directory = os.path.join(generated_dir, "corpus_tests") + if os.path.exists(output_directory): + shutil.rmtree(output_directory, ignore_errors=True) - s = c.subset_directory(5) + corpus = AcousticCorpus( + corpus_directory=multilingual_ipa_corpus_dir, + dictionary_path=english_us_ipa_dictionary, + use_mp=True, + temporary_directory=output_directory, + ) + corpus.load_corpus() + sd = corpus.split_directory + + s = corpus.subset_directory(5) assert os.path.exists(sd) assert os.path.exists(s) -def test_weird_words(weird_words_dir, temp_dir, sick_dict_path): - output_directory = os.path.join(temp_dir, 
"weird_words") - shutil.rmtree(output_directory, ignore_errors=True) - dictionary = MultispeakerDictionary(sick_dict_path, output_directory) - assert "i’m" not in dictionary.default_dictionary.words - assert "’m" not in dictionary.default_dictionary.words - assert dictionary.default_dictionary.words["i'm"][0]["pronunciation"] == ("ay", "m", "ih") - assert dictionary.default_dictionary.words["i'm"][1]["pronunciation"] == ("ay", "m") - assert dictionary.default_dictionary.words["'m"][0]["pronunciation"] == ("m",) - dictionary.write() - c = Corpus(weird_words_dir, output_directory, use_mp=False) - c.initialize_corpus(dictionary) - assert c.utterances["weird-words-weird-words"].oovs == [ +def test_weird_words(weird_words_dir, generated_dir, sick_dict_path): + output_directory = os.path.join(generated_dir, "corpus_tests") + if os.path.exists(output_directory): + shutil.rmtree(output_directory, ignore_errors=True) + + corpus = AcousticCorpus( + corpus_directory=weird_words_dir, + dictionary_path=sick_dict_path, + use_mp=True, + temporary_directory=output_directory, + ) + corpus.load_corpus() + assert "i’m" not in corpus.default_dictionary.words + assert "’m" not in corpus.default_dictionary.words + assert corpus.default_dictionary.words["i'm"][0]["pronunciation"] == ("ay", "m", "ih") + assert corpus.default_dictionary.words["i'm"][1]["pronunciation"] == ("ay", "m") + assert corpus.default_dictionary.words["'m"][0]["pronunciation"] == ("m",) + + assert corpus.utterances["weird-words-weird-words"].oovs == { "talking-ajfish", "asds-asda", "sdasd-me", - ] + } - dictionary.set_word_set(c.word_set) + corpus.set_lexicon_word_set(corpus.corpus_word_set) for w in ["i'm", "this'm", "sdsdsds'm", "'m"]: - _ = dictionary.default_dictionary.to_int(w) - print(dictionary.oovs_found) - assert "'m" not in dictionary.oovs_found + _ = corpus.default_dictionary.to_int(w) + print(corpus.oovs_found) + assert "'m" not in corpus.oovs_found -def test_punctuated(punctuated_dir, temp_dir, 
sick_dict_path): - output_directory = os.path.join(temp_dir, "punctuated") - shutil.rmtree(output_directory, ignore_errors=True) - dictionary = MultispeakerDictionary(sick_dict_path, output_directory) - dictionary.write() - c = Corpus(punctuated_dir, output_directory, dictionary_config=dictionary.config, use_mp=False) - c.initialize_corpus(dictionary) +def test_punctuated(punctuated_dir, generated_dir, sick_dict_path): + output_directory = os.path.join(generated_dir, "corpus_tests") + if os.path.exists(output_directory): + shutil.rmtree(output_directory, ignore_errors=True) + + corpus = AcousticCorpus( + corpus_directory=punctuated_dir, + dictionary_path=sick_dict_path, + use_mp=True, + temporary_directory=output_directory, + ) + corpus.load_corpus() assert ( - c.utterances["punctuated-punctuated"].text + corpus.utterances["punctuated-punctuated"].text == "oh yes they they you know they love her and so i mean" ) def test_alternate_punctuation( - punctuated_dir, temp_dir, sick_dict_path, different_punctuation_config + punctuated_dir, generated_dir, sick_dict_path, different_punctuation_config_path ): - train_config, align_config, dictionary_config = train_yaml_to_config( - different_punctuation_config + from montreal_forced_aligner.acoustic_modeling.trainer import TrainableAligner + + output_directory = os.path.join(generated_dir, "corpus_tests") + if os.path.exists(output_directory): + shutil.rmtree(output_directory, ignore_errors=True) + params = AcousticCorpus.extract_relevant_parameters( + TrainableAligner.parse_parameters(different_punctuation_config_path) ) - output_directory = os.path.join(temp_dir, "punctuated") - shutil.rmtree(output_directory, ignore_errors=True) - print(dictionary_config.punctuation) - dictionary = MultispeakerDictionary(sick_dict_path, output_directory, dictionary_config) - dictionary.write() - c = Corpus( - punctuated_dir, - output_directory, - dictionary_config, - use_mp=False, + params["use_mp"] = True + corpus = AcousticCorpus( + 
corpus_directory=punctuated_dir, + dictionary_path=sick_dict_path, + temporary_directory=output_directory, + **params ) - c.initialize_corpus(dictionary) + corpus.load_corpus() assert ( - c.utterances["punctuated-punctuated"].text + corpus.utterances["punctuated-punctuated"].text == "oh yes, they they, you know, they love her and so i mean" ) def test_xsampa_corpus( - xsampa_corpus_dir, xsampa_dict_path, temp_dir, generated_dir, different_punctuation_config + xsampa_corpus_dir, xsampa_dict_path, generated_dir, different_punctuation_config_path ): - train_config, align_config, dictionary_config = train_yaml_to_config( - different_punctuation_config + from montreal_forced_aligner.acoustic_modeling.trainer import TrainableAligner + + output_directory = os.path.join(generated_dir, "corpus_tests") + if os.path.exists(output_directory): + shutil.rmtree(output_directory, ignore_errors=True) + params = AcousticCorpus.extract_relevant_parameters( + TrainableAligner.parse_parameters(different_punctuation_config_path) ) - output_directory = os.path.join(temp_dir, "xsampa") - shutil.rmtree(output_directory, ignore_errors=True) - print(dictionary_config.punctuation) - dictionary = MultispeakerDictionary(xsampa_dict_path, output_directory, dictionary_config) - dictionary.write() - c = Corpus(xsampa_corpus_dir, output_directory, dictionary_config, use_mp=False) - c.initialize_corpus(dictionary) + params["use_mp"] = True + corpus = AcousticCorpus( + corpus_directory=xsampa_corpus_dir, + dictionary_path=xsampa_dict_path, + temporary_directory=output_directory, + **params + ) + corpus.load_corpus() assert ( - c.utterances["xsampa-michael"].text + corpus.utterances["michael-xsampa"].text == r"@bUr\tOU {bstr\{kt {bSaIr\ Abr\utseIzi {br\@geItIN @bor\n {b3kr\Ambi {bI5s@`n Ar\g thr\Ip@5eI Ar\dvAr\k".lower() ) diff --git a/tests/test_dict.py b/tests/test_dict.py index a98c20eb..bdb4b94d 100644 --- a/tests/test_dict.py +++ b/tests/test_dict.py @@ -1,26 +1,21 @@ import os -from 
montreal_forced_aligner.config.dictionary_config import DictionaryConfig -from montreal_forced_aligner.config.train_config import train_yaml_to_config -from montreal_forced_aligner.dictionary import MultispeakerDictionary, PronunciationDictionary - - -def ListLines(path): - lines = [] - thefile = open(path) - text = thefile.readlines() - for line in text: - stripped = line.strip() - if stripped != "": - lines.append(stripped) - return lines +from montreal_forced_aligner.dictionary.multispeaker import MultispeakerDictionary +from montreal_forced_aligner.dictionary.pronunciation import PronunciationDictionary def test_basic(basic_dict_path, generated_dir): - d = PronunciationDictionary(basic_dict_path, os.path.join(generated_dir, "basic")) - d.write() - assert set(d.config.phones) == {"sil", "sp", "spn", "phonea", "phoneb", "phonec"} - assert set(d.config.kaldi_non_silence_phones) == { + output_directory = os.path.join(generated_dir, "dictionary_tests") + dictionary = PronunciationDictionary( + dictionary_path=basic_dict_path, temporary_directory=output_directory + ) + dictionary.write() + dictionary.write(write_disambiguation=True) + + assert dictionary + assert len(dictionary) > 0 + assert set(dictionary.phones) == {"sil", "sp", "spn", "phonea", "phoneb", "phonec"} + assert set(dictionary.kaldi_non_silence_phones) == { "phonea_B", "phonea_I", "phonea_E", @@ -37,27 +32,45 @@ def test_basic(basic_dict_path, generated_dir): def test_extra_annotations(extra_annotations_path, generated_dir): - d = PronunciationDictionary(extra_annotations_path, os.path.join(generated_dir, "extra")) + output_directory = os.path.join(generated_dir, "dictionary_tests") + dictionary = MultispeakerDictionary( + dictionary_path=extra_annotations_path, temporary_directory=output_directory + ) + dictionary.dictionary_setup() + dictionary.write_lexicon_information() + d = dictionary.default_dictionary assert "{" in d.graphemes - d.write() def test_basic_noposition(basic_dict_path, 
generated_dir): - config = DictionaryConfig(position_dependent_phones=False) - d = PronunciationDictionary(basic_dict_path, os.path.join(generated_dir, "basic"), config) - d.write() - assert set(d.config.phones) == {"sil", "sp", "spn", "phonea", "phoneb", "phonec"} + output_directory = os.path.join(generated_dir, "dictionary_tests") + dictionary = MultispeakerDictionary( + dictionary_path=basic_dict_path, + position_dependent_phones=False, + temporary_directory=output_directory, + ) + dictionary.dictionary_setup() + dictionary.write_lexicon_information() + d = dictionary.default_dictionary + assert set(d.phones) == {"sil", "sp", "spn", "phonea", "phoneb", "phonec"} def test_frclitics(frclitics_dict_path, generated_dir): - d = PronunciationDictionary(frclitics_dict_path, os.path.join(generated_dir, "frclitics")) - d.write() + output_directory = os.path.join(generated_dir, "dictionary_tests") + dictionary = MultispeakerDictionary( + dictionary_path=frclitics_dict_path, + position_dependent_phones=False, + temporary_directory=output_directory, + ) + dictionary.dictionary_setup() + dictionary.write_lexicon_information() + d = dictionary.default_dictionary data = d.data() - assert d.silences == data.dictionary_config.silence_phones - assert d.config.multilingual_ipa == data.dictionary_config.multilingual_ipa + assert d.silences == data.silence_phones + assert d.multilingual_ipa == data.multilingual_ipa assert d.words_mapping == data.words_mapping - assert d.config.punctuation == data.dictionary_config.punctuation - assert d.config.clitic_markers == data.dictionary_config.clitic_markers + assert d.punctuation == data.punctuation + assert d.clitic_markers == data.clitic_markers assert d.oov_int == data.oov_int assert d.words == data.words assert not d.check_word("aujourd") @@ -70,11 +83,11 @@ def test_frclitics(frclitics_dict_path, generated_dir): assert d.split_clitics("m'appelle") == ["m'", "appelle"] assert d.split_clitics("m'm'appelle") == ["m'", "m'", "appelle"] 
assert d.split_clitics("c'est") == ["c'est"] - assert d.split_clitics("m'c'est") == ["m'", "c'est"] - assert d.split_clitics("purple-people-eater") == ["purple", "people", "eater"] + assert d.split_clitics("m'c'est") == ["m'", "c'", "est"] + assert d.split_clitics("purple-people-eater") == ["purple-people-eater"] assert d.split_clitics("m'appele") == ["m'", "appele"] assert d.split_clitics("m'ving-sic") == ["m'", "ving", "sic"] - assert d.split_clitics("flying'purple-people-eater") == ["flying'purple", "people", "eater"] + assert d.split_clitics("flying'purple-people-eater") == ["flying'purple-people-eater"] assert d.to_int("aujourd") == [d.oov_int] assert d.to_int("aujourd'hui") == [d.words_mapping["aujourd'hui"]] @@ -86,77 +99,112 @@ def test_frclitics(frclitics_dict_path, generated_dir): d.words_mapping["appelle"], ] assert d.to_int("c'est") == [d.words_mapping["c'est"]] - assert d.to_int("m'c'est") == [d.words_mapping["m'"], d.words_mapping["c'est"]] + assert d.to_int("m'c'est") == [ + d.words_mapping["m'"], + d.words_mapping["c'"], + d.words_mapping["est"], + ] assert d.to_int("purple-people-eater") == [d.oov_int] assert d.to_int("m'appele") == [d.words_mapping["m'"], d.oov_int] assert d.to_int("m'ving-sic") == [d.words_mapping["m'"], d.oov_int, d.oov_int] assert d.to_int("flying'purple-people-eater") == [d.oov_int] -def test_english_clitics(english_dictionary, generated_dir, basic_dictionary_config): - d = PronunciationDictionary( - english_dictionary, - os.path.join(generated_dir, "english_clitic_test"), - basic_dictionary_config, +def test_english_clitics(english_dictionary, generated_dir): + output_directory = os.path.join(generated_dir, "dictionary_tests") + dictionary = MultispeakerDictionary( + dictionary_path=english_dictionary, + position_dependent_phones=False, + temporary_directory=output_directory, ) - d.write() + dictionary.dictionary_setup() + dictionary.write_lexicon_information() + d = dictionary.default_dictionary assert 
d.split_clitics("l'orme's") == ["l'", "orme's"] assert d.to_int("l'orme's") == [d.words_mapping["l'"], d.words_mapping["orme's"]] -def test_devanagari(basic_dictionary_config): +def test_devanagari(english_dictionary, generated_dir): + output_directory = os.path.join(generated_dir, "dictionary_tests") + d = PronunciationDictionary( + dictionary_path=english_dictionary, + position_dependent_phones=False, + temporary_directory=output_directory, + ) test_cases = ["हैं", "हूं", "हौं"] for tc in test_cases: - assert tc == basic_dictionary_config.sanitize(tc) + assert tc == d.sanitize(tc) -def test_japanese(basic_dictionary_config): - assert "かぎ括弧" == basic_dictionary_config.sanitize("「かぎ括弧」") - assert "二重かぎ括弧" == basic_dictionary_config.sanitize("『二重かぎ括弧』") +def test_japanese(english_dictionary, generated_dir): + output_directory = os.path.join(generated_dir, "dictionary_tests") + d = PronunciationDictionary( + dictionary_path=english_dictionary, + position_dependent_phones=False, + temporary_directory=output_directory, + ) + assert "かぎ括弧" == d.sanitize("「かぎ括弧」") + assert "二重かぎ括弧" == d.sanitize("『二重かぎ括弧』") -def test_multilingual_ipa(basic_dictionary_config): +def test_multilingual_ipa(english_dictionary, generated_dir): + output_directory = os.path.join(generated_dir, "dictionary_tests") + dictionary = MultispeakerDictionary( + dictionary_path=english_dictionary, + position_dependent_phones=False, + multilingual_ipa=True, + temporary_directory=output_directory, + ) + + dictionary.dictionary_setup() + dictionary.write_lexicon_information() + d = dictionary.default_dictionary input_transcription = "m æ ŋ g oʊ dʒ aɪ".split() expected = tuple("m æ ŋ ɡ o ʊ d ʒ a ɪ".split()) - assert basic_dictionary_config.parse_ipa(input_transcription) == expected + assert d.parse_ipa(input_transcription) == expected input_transcription = "n ɔː ɹ job_name".split() expected = tuple("n ɔ ɹ job_name".split()) - assert basic_dictionary_config.parse_ipa(input_transcription) == expected + assert 
d.parse_ipa(input_transcription) == expected input_transcription = "t ʌ tʃ ə b l̩".split() expected = tuple("t ʌ t ʃ ə b l".split()) - assert basic_dictionary_config.parse_ipa(input_transcription) == expected + assert d.parse_ipa(input_transcription) == expected -def test_xsampa_dir(xsampa_dict_path, generated_dir, different_punctuation_config): +def test_xsampa_dir(xsampa_dict_path, generated_dir): + output_directory = os.path.join(generated_dir, "dictionary_tests") - train_config, align_config, dictionary_config = train_yaml_to_config( - different_punctuation_config - ) - d = PronunciationDictionary( - xsampa_dict_path, os.path.join(generated_dir, "xsampa"), dictionary_config + dictionary = MultispeakerDictionary( + dictionary_path=xsampa_dict_path, + position_dependent_phones=False, + multilingual_ipa=True, + punctuation=list(".-']["), + temporary_directory=output_directory, ) - d.write() + dictionary.dictionary_setup() + dictionary.write_lexicon_information() + d = dictionary.default_dictionary print(d.words) - assert not d.config.clitic_set + assert not d.clitic_set assert d.split_clitics(r"r\{und") == [r"r\{und"] assert d.split_clitics("{bI5s@`n") == ["{bI5s@`n"] assert d.words[r"r\{und"] -def test_multispeaker_config( - multispeaker_dictionary_config, sick_corpus, basic_dictionary_config, generated_dir -): +def test_multispeaker_config(multispeaker_dictionary_config_path, sick_corpus, generated_dir): + output_directory = os.path.join(generated_dir, "dictionary_tests") dictionary = MultispeakerDictionary( - multispeaker_dictionary_config, - os.path.join(generated_dir, "multispeaker"), - basic_dictionary_config, - word_set=sick_corpus.word_set, + dictionary_path=multispeaker_dictionary_config_path, + position_dependent_phones=False, + multilingual_ipa=True, + punctuation=list(".-']["), + temporary_directory=output_directory, ) - dictionary.write() + dictionary.dictionary_setup() + dictionary.write_lexicon_information() for d in 
dictionary.dictionary_mapping.values(): - assert d.silences.issubset(dictionary.config.silence_phones) - assert d.config.non_silence_phones.issubset(dictionary.config.non_silence_phones) + assert d.silence_phones.issubset(dictionary.silence_phones) + assert d.non_silence_phones.issubset(dictionary.non_silence_phones) diff --git a/tests/test_g2p.py b/tests/test_g2p.py index e876d2a4..a724ece3 100644 --- a/tests/test_g2p.py +++ b/tests/test_g2p.py @@ -2,12 +2,15 @@ import pytest -from montreal_forced_aligner.config.dictionary_config import DictionaryConfig -from montreal_forced_aligner.config.train_g2p_config import load_basic_train_g2p_config -from montreal_forced_aligner.g2p.generator import PyniniDictionaryGenerator, clean_up_word +from montreal_forced_aligner.dictionary.pronunciation import PronunciationDictionary +from montreal_forced_aligner.g2p.generator import ( + PyniniCorpusGenerator, + PyniniWordListGenerator, + clean_up_word, +) from montreal_forced_aligner.g2p.trainer import G2P_DISABLED, PyniniTrainer from montreal_forced_aligner.models import G2PModel -from montreal_forced_aligner.utils import get_mfa_version, get_pretrained_g2p_path +from montreal_forced_aligner.utils import get_mfa_version def test_clean_up_word(): @@ -17,55 +20,70 @@ def test_clean_up_word(): assert m == ["+"] -def test_check_bracketed(): +def test_check_bracketed(sick_dict): """Checks if the brackets are removed correctly and handling an empty string works""" word_set = ["uh", "(the)", "sick", "", "[a]", "{cold}", ""] expected_result = ["uh", "sick", ""] - dictionary_config = DictionaryConfig() + dictionary_config = PronunciationDictionary(dictionary_path=sick_dict) assert [x for x in word_set if not dictionary_config.check_bracketed(x)] == expected_result -def test_training(sick_dict, sick_g2p_model_path, temp_dir): +def test_training(sick_dict_path, sick_g2p_model_path, temp_dir): if G2P_DISABLED: pytest.skip("No Pynini found") - train_config, dictionary_config = 
load_basic_train_g2p_config() - sick_dict = sick_dict.default_dictionary - train_config.random_starts = 1 - train_config.max_iterations = 5 trainer = PyniniTrainer( - sick_dict, sick_g2p_model_path, temp_directory=temp_dir, train_config=train_config + dictionary_path=sick_dict_path, + temporary_directory=temp_dir, + random_starts=1, + num_iterations=5, + evaluate=True, ) - trainer.validate() + trainer.setup() trainer.train() + trainer.export_model(sick_g2p_model_path) model = G2PModel(sick_g2p_model_path, root_directory=temp_dir) assert model.meta["version"] == get_mfa_version() assert model.meta["architecture"] == "pynini" - assert model.meta["phones"] == sick_dict.config.non_silence_phones + assert model.meta["phones"] == trainer.non_silence_phones + assert model.meta["graphemes"] == trainer.graphemes + trainer.cleanup() -def test_generator(sick_g2p_model_path, sick_corpus, g2p_sick_output): +def test_generator(sick_g2p_model_path, sick_corpus, g2p_sick_output, temp_dir): if G2P_DISABLED: pytest.skip("No Pynini found") - model = G2PModel(sick_g2p_model_path) - dictionary_config = DictionaryConfig() - - assert not model.validate(sick_corpus.word_set) - assert model.validate( - [x for x in sick_corpus.word_set if not dictionary_config.check_bracketed(x)] + output_directory = os.path.join(temp_dir, "g2p_tests") + gen = PyniniCorpusGenerator( + g2p_model_path=sick_g2p_model_path, + corpus_directory=sick_corpus, + temporary_directory=output_directory, ) - gen = PyniniDictionaryGenerator(model, sick_corpus.word_set) - gen.output(g2p_sick_output) + + gen.setup() + assert not gen.g2p_model.validate(gen.corpus_word_set) + assert gen.g2p_model.validate([x for x in gen.corpus_word_set if not gen.check_bracketed(x)]) + + gen.export_pronunciations(g2p_sick_output) assert os.path.exists(g2p_sick_output) + gen.cleanup() -def test_generator_pretrained(english_g2p_model): +def test_generator_pretrained(english_g2p_model, temp_dir): if G2P_DISABLED: pytest.skip("No Pynini found") 
- model_path = get_pretrained_g2p_path("english_g2p") - model = G2PModel(model_path) words = ["petted", "petted-patted", "pedal"] - gen = PyniniDictionaryGenerator(model, words, num_pronunciations=3) - results = gen.generate() + output_directory = os.path.join(temp_dir, "g2p_tests") + word_list_path = os.path.join(output_directory, "word_list.txt") + os.makedirs(output_directory, exist_ok=True) + with open(word_list_path, "w", encoding="utf8") as f: + for w in words: + f.write(w + "\n") + gen = PyniniWordListGenerator( + g2p_model_path="english_g2p", word_list_path=word_list_path, num_pronunciations=3 + ) + gen.setup() + results = gen.generate_pronunciations() print(results) assert len(results["petted"]) == 3 + gen.cleanup() diff --git a/tests/test_gui.py b/tests/test_gui.py index 4bc815cb..7180398d 100644 --- a/tests/test_gui.py +++ b/tests/test_gui.py @@ -1,34 +1,31 @@ import os -from montreal_forced_aligner.corpus import Corpus -from montreal_forced_aligner.dictionary import MultispeakerDictionary +from montreal_forced_aligner.corpus.acoustic_corpus import AcousticCorpus def test_save_text_lab( basic_dict_path, basic_corpus_dir, generated_dir, - default_feature_config, - basic_dictionary_config, ): - dictionary = MultispeakerDictionary( - basic_dict_path, os.path.join(generated_dir, "basic"), basic_dictionary_config + output_directory = os.path.join(generated_dir, "corpus_tests") + corpus = AcousticCorpus( + corpus_directory=basic_corpus_dir, + dictionary_path=basic_dict_path, + use_mp=True, + temporary_directory=output_directory, ) - dictionary.write() - output_directory = os.path.join(generated_dir, "basic") - c = Corpus(basic_corpus_dir, output_directory, basic_dictionary_config, use_mp=True) - c.initialize_corpus(dictionary) - c.files["acoustic_corpus"].save() + corpus._load_corpus() + corpus.files["acoustic_corpus"].save() -def test_flac_tg( - basic_dict_path, flac_tg_corpus_dir, temp_dir, default_feature_config, basic_dictionary_config -): - temp = 
os.path.join(temp_dir, "flac_tg_corpus") - dictionary = MultispeakerDictionary( - basic_dict_path, os.path.join(temp, "basic"), basic_dictionary_config +def test_flac_tg(basic_dict_path, flac_tg_corpus_dir, generated_dir): + output_directory = os.path.join(generated_dir, "corpus_tests") + corpus = AcousticCorpus( + corpus_directory=flac_tg_corpus_dir, + dictionary_path=basic_dict_path, + use_mp=True, + temporary_directory=output_directory, ) - dictionary.write() - c = Corpus(flac_tg_corpus_dir, temp, basic_dictionary_config, use_mp=False) - c.initialize_corpus(dictionary) - c.files["61-70968-0000"].save() + corpus._load_corpus() + corpus.files["61-70968-0000"].save() diff --git a/tests/test_textgrid.py b/tests/test_textgrid.py index f2ab50de..38be43ca 100644 --- a/tests/test_textgrid.py +++ b/tests/test_textgrid.py @@ -1,15 +1,20 @@ import os -from montreal_forced_aligner.dictionary import PronunciationDictionary -from montreal_forced_aligner.models import DictionaryModel +from montreal_forced_aligner.dictionary.multispeaker import MultispeakerDictionary from montreal_forced_aligner.textgrid import CtmInterval -def test_mapping(english_us_ipa_dictionary, generated_dir, basic_dictionary_config): - output_directory = os.path.join(generated_dir, "ipa_temp") - d = PronunciationDictionary( - DictionaryModel(english_us_ipa_dictionary), output_directory, basic_dictionary_config +def test_mapping(english_us_ipa_dictionary, generated_dir): + output_directory = os.path.join(generated_dir, "textgrid_tests") + dictionary = MultispeakerDictionary( + dictionary_path=english_us_ipa_dictionary, + position_dependent_phones=False, + multilingual_ipa=True, + temporary_directory=output_directory, ) + dictionary.dictionary_setup() + dictionary.write_lexicon_information() + d = dictionary.default_dictionary u = "utt" cur_phones = [ CtmInterval(2.25, 2.33, "t", u), diff --git a/tox.ini b/tox.ini index 1db35516..be0c2016 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = 
py38-{win,unix},coverage,lint,check-formatting,manifest +envlist = py39-{win,unix},coverage,lint,check-formatting,manifest minversion = 3.18.0 requires = tox-conda isolated_build = true @@ -36,12 +36,12 @@ commands = coverage xml coverage html depends = - py38-{win,unix} + py39-{win,unix} ; This env just runs `black` and fails tox if it's not formatted correctly. ; If this env fails on CI, run `tox -e format` locally in order to apply changes. [testenv:check-formatting] -basepython = python3.8 +basepython = python3.9 deps = black==21.8b0 skip_install = true commands = @@ -49,7 +49,7 @@ commands = [testenv:pkg_meta] description = check that the long description is valid -basepython = python3.8 +basepython = python3.9 skip_install = true deps = build>=0.0.4 @@ -65,7 +65,7 @@ ignore = E203 W503 [testenv:docs] -basepython = python3.8 +basepython = python3.9 skip_install=true conda_env = rtd_environment.yml commands = @@ -73,13 +73,13 @@ commands = sphinx-build -v -E -a -n -T -b html docs/source docs/build [testenv:manifest] -basepython = python3.8 +basepython = python3.9 deps = check-manifest skip_install = true commands = check-manifest [testenv:format] -basepython = python3.8 +basepython = python3.9 deps = black==21.8b0 skip_install = true commands = @@ -87,7 +87,7 @@ commands = [gh-actions] python = - 3.8: py38-unix,coverage + 3.9: py39-unix,coverage [testenv:dev] description = dev environment with all deps at {envdir}