diff --git a/TTS/api.py b/TTS/api.py index 992fbe69e9..250ed1a0d9 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -1,3 +1,4 @@ +import logging import tempfile import warnings from pathlib import Path @@ -9,6 +10,8 @@ from TTS.utils.manage import ModelManager from TTS.utils.synthesizer import Synthesizer +logger = logging.getLogger(__name__) + class TTS(nn.Module): """TODO: Add voice conversion and Capacitron support.""" @@ -59,7 +62,7 @@ def __init__( gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. """ super().__init__() - self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar, verbose=False) + self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar) self.config = load_config(config_path) if config_path else None self.synthesizer = None self.voice_converter = None @@ -122,7 +125,7 @@ def get_models_file_path(): @staticmethod def list_models(): - return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False, verbose=False).list_models() + return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False).list_models() def download_model_by_name(self, model_name: str): model_path, config_path, model_item = self.manager.download_model(model_name) diff --git a/TTS/bin/compute_attention_masks.py b/TTS/bin/compute_attention_masks.py index faadf6901d..207b17e9c4 100644 --- a/TTS/bin/compute_attention_masks.py +++ b/TTS/bin/compute_attention_masks.py @@ -1,5 +1,6 @@ import argparse import importlib +import logging import os from argparse import RawTextHelpFormatter @@ -13,9 +14,12 @@ from TTS.tts.models import setup_model from TTS.tts.utils.text.characters import make_symbols, phonemes, symbols from TTS.utils.audio import AudioProcessor +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger from TTS.utils.io import load_checkpoint if __name__ == "__main__": + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + # pylint: disable=bad-option-value parser = argparse.ArgumentParser( description="""Extract attention masks from trained Tacotron/Tacotron2 models. diff --git a/TTS/bin/compute_embeddings.py b/TTS/bin/compute_embeddings.py index 5b5a37df73..6795241a73 100644 --- a/TTS/bin/compute_embeddings.py +++ b/TTS/bin/compute_embeddings.py @@ -1,4 +1,5 @@ import argparse +import logging import os from argparse import RawTextHelpFormatter @@ -10,6 +11,7 @@ from TTS.tts.datasets import load_tts_samples from TTS.tts.utils.managers import save_file from TTS.tts.utils.speakers import SpeakerManager +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger def compute_embeddings( @@ -100,6 +102,8 @@ def compute_embeddings( if __name__ == "__main__": + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + parser = argparse.ArgumentParser( description="""Compute embedding vectors for each audio file in a dataset and store them keyed by `{dataset_name}#{file_path}` in a .pth file\n\n""" """ diff --git a/TTS/bin/compute_statistics.py b/TTS/bin/compute_statistics.py index 3ab7ea7a3b..dc5423a691 100755 --- a/TTS/bin/compute_statistics.py +++ b/TTS/bin/compute_statistics.py @@ -3,6 +3,7 @@ import argparse import glob +import logging import os import numpy as np @@ -12,10 +13,13 @@ from TTS.config import load_config from TTS.tts.datasets import load_tts_samples from TTS.utils.audio import AudioProcessor +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger def main(): """Run preprocessing process.""" + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + parser = argparse.ArgumentParser(description="Compute mean and variance of spectrogtram features.") parser.add_argument("config_path", type=str, help="TTS config file path to define audio processin parameters.") parser.add_argument("out_path", type=str, help="save path (directory and filename).") diff --git a/TTS/bin/eval_encoder.py b/TTS/bin/eval_encoder.py index 60fed13932..8327851ca7 100644 --- a/TTS/bin/eval_encoder.py +++ b/TTS/bin/eval_encoder.py @@ -1,4 +1,5 @@ import argparse +import logging from argparse import RawTextHelpFormatter import torch @@ -7,6 +8,7 @@ from TTS.config import load_config from TTS.tts.datasets import load_tts_samples from TTS.tts.utils.speakers import SpeakerManager +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger def compute_encoder_accuracy(dataset_items, encoder_manager): @@ -51,6 +53,8 @@ def compute_encoder_accuracy(dataset_items, encoder_manager): if __name__ == "__main__": + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + parser = argparse.ArgumentParser( description="""Compute the accuracy of the encoder.\n\n""" """ diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py index 16ad36b8dc..83f2ca21c4 100755 --- a/TTS/bin/extract_tts_spectrograms.py +++ b/TTS/bin/extract_tts_spectrograms.py @@ -2,6 +2,7 @@ """Extract Mel spectrograms with teacher forcing.""" import argparse +import logging import os import numpy as np @@ -17,11 +18,12 @@ from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor from TTS.utils.audio.numpy_transforms import quantize +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger use_cuda = torch.cuda.is_available() -def setup_loader(ap, r, verbose=False): +def setup_loader(ap, r): tokenizer, _ = TTSTokenizer.init_from_config(c) dataset = TTSDataset( outputs_per_step=r, @@ -37,7 +39,6 @@ def setup_loader(ap, r, verbose=False): phoneme_cache_path=c.phoneme_cache_path, precompute_num_workers=0, use_noise_augment=False, - verbose=verbose, speaker_id_mapping=speaker_manager.name_to_id if c.use_speaker_embedding else None, d_vector_mapping=speaker_manager.embeddings if c.use_d_vector_file else None, ) @@ -257,7 +258,7 @@ def main(args): # pylint: disable=redefined-outer-name print("\n > Model has {} parameters".format(num_params), flush=True) # set r r = 1 if c.model.lower() == "glow_tts" else model.decoder.r - own_loader = setup_loader(ap, r, verbose=True) + own_loader = setup_loader(ap, r) extract_spectrograms( own_loader, @@ -272,6 +273,8 @@ def main(args): # pylint: disable=redefined-outer-name if __name__ == "__main__": + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + parser = argparse.ArgumentParser() parser.add_argument("--config_path", type=str, help="Path to config file for training.", required=True) parser.add_argument("--checkpoint_path", type=str, help="Model file to be restored.", required=True) diff --git a/TTS/bin/find_unique_chars.py b/TTS/bin/find_unique_chars.py index f476ca5ddb..0519d43769 100644 --- a/TTS/bin/find_unique_chars.py +++ b/TTS/bin/find_unique_chars.py @@ -1,13 +1,17 @@ """Find all the unique characters in a dataset""" import argparse +import logging from argparse import RawTextHelpFormatter from TTS.config import load_config from TTS.tts.datasets import find_unique_chars, load_tts_samples +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger def main(): + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + # pylint: disable=bad-option-value parser = argparse.ArgumentParser( description="""Find all the unique characters or phonemes in a dataset.\n\n""" diff --git a/TTS/bin/find_unique_phonemes.py b/TTS/bin/find_unique_phonemes.py index 48f2e7b740..d99acb9893 100644 --- a/TTS/bin/find_unique_phonemes.py +++ b/TTS/bin/find_unique_phonemes.py @@ -1,6 +1,7 @@ """Find all the unique characters in a dataset""" import argparse +import logging import multiprocessing from argparse import RawTextHelpFormatter @@ -9,6 +10,7 @@ from TTS.config import load_config from TTS.tts.datasets import load_tts_samples from TTS.tts.utils.text.phonemizers import Gruut +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger def compute_phonemes(item): @@ -18,6 +20,8 @@ def compute_phonemes(item): def main(): + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + # pylint: disable=W0601 global c, phonemizer # pylint: disable=bad-option-value diff --git a/TTS/bin/remove_silence_using_vad.py b/TTS/bin/remove_silence_using_vad.py index a1eaf4c9a7..f6d09d6bf1 100755 --- a/TTS/bin/remove_silence_using_vad.py +++ b/TTS/bin/remove_silence_using_vad.py @@ -1,5 +1,6 @@ import argparse import glob +import logging import multiprocessing import os import pathlib @@ -7,6 +8,7 @@ import torch from tqdm import tqdm +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger from TTS.utils.vad import get_vad_model_and_utils, remove_silence torch.set_num_threads(1) @@ -75,6 +77,8 @@ def preprocess_audios(): if __name__ == "__main__": + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + parser = argparse.ArgumentParser( description="python TTS/bin/remove_silence_using_vad.py -i=VCTK-Corpus/ -o=VCTK-Corpus-removed-silence/ -g=wav48_silence_trimmed/*/*_mic1.flac --trim_just_beginning_and_end True" ) diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index b06c93f7d1..0464cb2943 100755 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -3,12 +3,17 @@ import argparse import contextlib +import logging import sys from argparse import RawTextHelpFormatter # pylint: disable=redefined-outer-name, unused-argument from pathlib import Path +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger + +logger = logging.getLogger(__name__) + description = """ Synthesize speech on command line. @@ -142,6 +147,8 @@ def str2bool(v): def main(): + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + parser = argparse.ArgumentParser( description=description.replace(" ```\n", ""), formatter_class=RawTextHelpFormatter, @@ -435,31 +442,37 @@ def main(): # query speaker ids of a multi-speaker model. if args.list_speaker_idxs: - print( - " > Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model." + if synthesizer.tts_model.speaker_manager is None: + logger.info("Model only has a single speaker.") + return + logger.info( + "Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model." ) - print(synthesizer.tts_model.speaker_manager.name_to_id) + logger.info(synthesizer.tts_model.speaker_manager.name_to_id) return # query langauge ids of a multi-lingual model. if args.list_language_idxs: - print( - " > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model." + if synthesizer.tts_model.language_manager is None: + logger.info("Monolingual model.") + return + logger.info( + "Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model." ) - print(synthesizer.tts_model.language_manager.name_to_id) + logger.info(synthesizer.tts_model.language_manager.name_to_id) return # check the arguments against a multi-speaker model. if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav): - print( - " [!] Looks like you use a multi-speaker model. Define `--speaker_idx` to " + logger.error( + "Looks like you use a multi-speaker model. Define `--speaker_idx` to " "select the target speaker. You can list the available speakers for this model by `--list_speaker_idxs`." ) return # RUN THE SYNTHESIS if args.text: - print(" > Text: {}".format(args.text)) + logger.info("Text: %s", args.text) # kick it if tts_path is not None: @@ -484,8 +497,8 @@ def main(): ) # save the results - print(" > Saving output to {}".format(args.out_path)) synthesizer.save_wav(wav, args.out_path, pipe_out=pipe_out) + logger.info("Saved output to %s", args.out_path) if __name__ == "__main__": diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index 6a8cd7b444..c0292743bf 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import logging import os import sys import time @@ -19,6 +20,7 @@ from TTS.encoder.utils.visual import plot_embeddings from TTS.tts.datasets import load_tts_samples from TTS.utils.audio import AudioProcessor +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger from TTS.utils.samplers import PerfectBatchSampler from TTS.utils.training import check_update @@ -31,7 +33,7 @@ print(" > Number of GPUs: ", num_gpus) -def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False): +def setup_loader(ap: AudioProcessor, is_val: bool = False): num_utter_per_class = c.num_utter_per_class if not is_val else c.eval_num_utter_per_class num_classes_in_batch = c.num_classes_in_batch if not is_val else c.eval_num_classes_in_batch @@ -42,7 +44,6 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False voice_len=c.voice_len, num_utter_per_class=num_utter_per_class, num_classes_in_batch=num_classes_in_batch, - verbose=verbose, augmentation_config=c.audio_augmentation if not is_val else None, use_torch_spec=c.model_params.get("use_torch_spec", False), ) @@ -278,9 +279,9 @@ def main(args): # pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=True) - train_data_loader, train_classes, map_classid_to_classname = setup_loader(ap, is_val=False, verbose=True) + train_data_loader, train_classes, map_classid_to_classname = setup_loader(ap, is_val=False) if c.run_eval: - eval_data_loader, _, _ = setup_loader(ap, is_val=True, verbose=True) + eval_data_loader, _, _ = setup_loader(ap, is_val=True) else: eval_data_loader = None @@ -316,6 +317,8 @@ def main(args): # pylint: disable=redefined-outer-name if __name__ == "__main__": + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + args, c, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = init_training() try: diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index bdb4f6f691..6d6342a762 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -1,3 +1,4 @@ +import logging import os from dataclasses import dataclass, field @@ -6,6 +7,7 @@ from TTS.config import load_config, register_config from TTS.tts.datasets import load_tts_samples from TTS.tts.models import setup_model +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger @dataclass @@ -15,6 +17,8 @@ class TrainTTSArgs(TrainerArgs): def main(): """Run `tts` model training directly by a `config.json` file.""" + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + # init trainer args train_args = TrainTTSArgs() parser = train_args.init_argparse(arg_prefix="") diff --git a/TTS/bin/train_vocoder.py b/TTS/bin/train_vocoder.py index 32ecd7bdc3..221ff4cff0 100644 --- a/TTS/bin/train_vocoder.py +++ b/TTS/bin/train_vocoder.py @@ -1,3 +1,4 @@ +import logging import os from dataclasses import dataclass, field @@ -5,6 +6,7 @@ from TTS.config import load_config, register_config from TTS.utils.audio import AudioProcessor +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data from TTS.vocoder.models import setup_model @@ -16,6 +18,8 @@ class TrainVocoderArgs(TrainerArgs): def main(): """Run `tts` model training directly by a `config.json` file.""" + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + # init trainer args train_args = TrainVocoderArgs() parser = train_args.init_argparse(arg_prefix="") diff --git a/TTS/bin/tune_wavegrad.py b/TTS/bin/tune_wavegrad.py index a4b10009d7..df2923952d 100644 --- a/TTS/bin/tune_wavegrad.py +++ b/TTS/bin/tune_wavegrad.py @@ -1,6 +1,7 @@ """Search a good noise schedule for WaveGrad for a given number of inference iterations""" import argparse +import logging from itertools import product as cartesian_product import numpy as np @@ -10,11 +11,14 @@ from TTS.config import load_config from TTS.utils.audio import AudioProcessor +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger from TTS.vocoder.datasets.preprocess import load_wav_data from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset from TTS.vocoder.models import setup_model if __name__ == "__main__": + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + parser = argparse.ArgumentParser() parser.add_argument("--model_path", type=str, help="Path to model checkpoint.") parser.add_argument("--config_path", type=str, help="Path to model config file.") @@ -55,7 +59,6 @@ return_segments=False, use_noise_augment=False, use_cache=False, - verbose=True, ) loader = DataLoader( dataset, diff --git a/TTS/encoder/dataset.py b/TTS/encoder/dataset.py index 582b1fe9ca..81385c6c1f 100644 --- a/TTS/encoder/dataset.py +++ b/TTS/encoder/dataset.py @@ -1,3 +1,4 @@ +import logging import random import torch @@ -5,6 +6,8 @@ from TTS.encoder.utils.generic_utils import AugmentWAV +logger = logging.getLogger(__name__) + class EncoderDataset(Dataset): def __init__( @@ -15,7 +18,6 @@ def __init__( voice_len=1.6, num_classes_in_batch=64, num_utter_per_class=10, - verbose=False, augmentation_config=None, use_torch_spec=None, ): @@ -24,7 +26,6 @@ def __init__( ap (TTS.tts.utils.AudioProcessor): audio processor object. meta_data (list): list of dataset instances. seq_len (int): voice segment length in seconds. - verbose (bool): print diagnostic information. """ super().__init__() self.config = config @@ -33,7 +34,6 @@ def __init__( self.seq_len = int(voice_len * self.sample_rate) self.num_utter_per_class = num_utter_per_class self.ap = ap - self.verbose = verbose self.use_torch_spec = use_torch_spec self.classes, self.items = self.__parse_items() @@ -50,13 +50,12 @@ def __init__( if "gaussian" in augmentation_config.keys(): self.gaussian_augmentation_config = augmentation_config["gaussian"] - if self.verbose: - print("\n > DataLoader initialization") - print(f" | > Classes per Batch: {num_classes_in_batch}") - print(f" | > Number of instances : {len(self.items)}") - print(f" | > Sequence length: {self.seq_len}") - print(f" | > Num Classes: {len(self.classes)}") - print(f" | > Classes: {self.classes}") + logger.info("DataLoader initialization") + logger.info(" | Classes per batch: %d", num_classes_in_batch) + logger.info(" | Number of instances: %d", len(self.items)) + logger.info(" | Sequence length: %d", self.seq_len) + logger.info(" | Number of classes: %d", len(self.classes)) + logger.info(" | Classes: %d", self.classes) def load_wav(self, filename): audio = self.ap.load_wav(filename, sr=self.ap.sample_rate) diff --git a/TTS/encoder/losses.py b/TTS/encoder/losses.py index 5b5aa0fc48..2e27848c31 100644 --- a/TTS/encoder/losses.py +++ b/TTS/encoder/losses.py @@ -1,7 +1,11 @@ +import logging + import torch import torch.nn.functional as F from torch import nn +logger = logging.getLogger(__name__) + # adapted from https://github.com/cvqluu/GE2E-Loss class GE2ELoss(nn.Module): @@ -23,7 +27,7 @@ def __init__(self, init_w=10.0, init_b=-5.0, loss_method="softmax"): self.b = nn.Parameter(torch.tensor(init_b)) self.loss_method = loss_method - print(" > Initialized Generalized End-to-End loss") + logger.info("Initialized Generalized End-to-End loss") assert self.loss_method in ["softmax", "contrast"] @@ -139,7 +143,7 @@ def __init__(self, init_w=10.0, init_b=-5.0): self.b = nn.Parameter(torch.tensor(init_b)) self.criterion = torch.nn.CrossEntropyLoss() - print(" > Initialized Angular Prototypical loss") + logger.info("Initialized Angular Prototypical loss") def forward(self, x, _label=None): """ @@ -177,7 +181,7 @@ def __init__(self, embedding_dim, n_speakers): self.criterion = torch.nn.CrossEntropyLoss() self.fc = nn.Linear(embedding_dim, n_speakers) - print("Initialised Softmax Loss") + logger.info("Initialised Softmax Loss") def forward(self, x, label=None): # reshape for compatibility @@ -212,7 +216,7 @@ def __init__(self, embedding_dim, n_speakers, init_w=10.0, init_b=-5.0): self.softmax = SoftmaxLoss(embedding_dim, n_speakers) self.angleproto = AngleProtoLoss(init_w, init_b) - print("Initialised SoftmaxAnglePrototypical Loss") + logger.info("Initialised SoftmaxAnglePrototypical Loss") def forward(self, x, label=None): """ diff --git a/TTS/encoder/models/base_encoder.py b/TTS/encoder/models/base_encoder.py index 957ea3c4ca..374062463d 100644 --- a/TTS/encoder/models/base_encoder.py +++ b/TTS/encoder/models/base_encoder.py @@ -1,3 +1,5 @@ +import logging + import numpy as np import torch import torchaudio @@ -8,6 +10,8 @@ from TTS.utils.generic_utils import set_init_dict from TTS.utils.io import load_fsspec +logger = logging.getLogger(__name__) + class PreEmphasis(nn.Module): def __init__(self, coefficient=0.97): @@ -118,13 +122,13 @@ def load_checkpoint( state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) try: self.load_state_dict(state["model"]) - print(" > Model fully restored. ") + logger.info("Model fully restored. ") except (KeyError, RuntimeError) as error: # If eval raise the error if eval: raise error - print(" > Partial model initialization.") + logger.info("Partial model initialization.") model_dict = self.state_dict() model_dict = set_init_dict(model_dict, state["model"], c) self.load_state_dict(model_dict) @@ -135,7 +139,7 @@ def load_checkpoint( try: criterion.load_state_dict(state["criterion"]) except (KeyError, RuntimeError) as error: - print(" > Criterion load ignored because of:", error) + logger.exception("Criterion load ignored because of: %s", error) # instance and load the criterion for the encoder classifier in inference time if ( diff --git a/TTS/encoder/utils/generic_utils.py b/TTS/encoder/utils/generic_utils.py index 88ed71d3f4..495b4def5a 100644 --- a/TTS/encoder/utils/generic_utils.py +++ b/TTS/encoder/utils/generic_utils.py @@ -1,4 +1,5 @@ import glob +import logging import os import random @@ -8,6 +9,8 @@ from TTS.encoder.models.lstm import LSTMSpeakerEncoder from TTS.encoder.models.resnet import ResNetSpeakerEncoder +logger = logging.getLogger(__name__) + class AugmentWAV(object): def __init__(self, ap, augmentation_config): @@ -38,8 +41,10 @@ def __init__(self, ap, augmentation_config): self.noise_list[noise_dir] = [] self.noise_list[noise_dir].append(wav_file) - print( - f" | > Using Additive Noise Augmentation: with {len(additive_files)} audios instances from {self.additive_noise_types}" + logger.info( + "Using Additive Noise Augmentation: with %d audios instances from %s", + len(additive_files), + self.additive_noise_types, ) self.use_rir = False @@ -50,7 +55,7 @@ def __init__(self, ap, augmentation_config): self.rir_files = glob.glob(os.path.join(self.rir_config["rir_path"], "**/*.wav"), recursive=True) self.use_rir = True - print(f" | > Using RIR Noise Augmentation: with {len(self.rir_files)} audios instances") + logger.info("Using RIR Noise Augmentation: with %d audios instances", len(self.rir_files)) self.create_augmentation_global_list() diff --git a/TTS/encoder/utils/prepare_voxceleb.py b/TTS/encoder/utils/prepare_voxceleb.py index 5a68c3075a..8f571dd2c7 100644 --- a/TTS/encoder/utils/prepare_voxceleb.py +++ b/TTS/encoder/utils/prepare_voxceleb.py @@ -21,13 +21,15 @@ import csv import hashlib +import logging import os import subprocess import sys import zipfile import soundfile as sf -from absl import logging + +logger = logging.getLogger(__name__) SUBSETS = { "vox1_dev_wav": [ @@ -77,14 +79,14 @@ def download_and_extract(directory, subset, urls): zip_filepath = os.path.join(directory, url.split("/")[-1]) if os.path.exists(zip_filepath): continue - logging.info("Downloading %s to %s" % (url, zip_filepath)) + logger.info("Downloading %s to %s" % (url, zip_filepath)) subprocess.call( "wget %s --user %s --password %s -O %s" % (url, USER["user"], USER["password"], zip_filepath), shell=True, ) statinfo = os.stat(zip_filepath) - logging.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size)) + logger.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size)) # concatenate all parts into zip files if ".zip" not in zip_filepath: @@ -118,9 +120,9 @@ def exec_cmd(cmd): try: retcode = subprocess.call(cmd, shell=True) if retcode < 0: - logging.info(f"Child was terminated by signal {retcode}") + logger.info(f"Child was terminated by signal {retcode}") except OSError as e: - logging.info(f"Execution failed: {e}") + logger.info(f"Execution failed: {e}") retcode = -999 return retcode @@ -134,11 +136,11 @@ def decode_aac_with_ffmpeg(aac_file, wav_file): bool, True if success. """ cmd = f"ffmpeg -i {aac_file} {wav_file}" - logging.info(f"Decoding aac file using command line: {cmd}") + logger.info(f"Decoding aac file using command line: {cmd}") ret = exec_cmd(cmd) if ret != 0: - logging.error(f"Failed to decode aac file with retcode {ret}") - logging.error("Please check your ffmpeg installation.") + logger.error(f"Failed to decode aac file with retcode {ret}") + logger.error("Please check your ffmpeg installation.") return False return True @@ -152,7 +154,7 @@ def convert_audio_and_make_label(input_dir, subset, output_dir, output_file): output_file: the name of the newly generated csv file. e.g. vox1_dev_wav.csv """ - logging.info("Preprocessing audio and label for subset %s" % subset) + logger.info("Preprocessing audio and label for subset %s" % subset) source_dir = os.path.join(input_dir, subset) files = [] @@ -190,7 +192,7 @@ def convert_audio_and_make_label(input_dir, subset, output_dir, output_file): writer.writerow(["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"]) for wav_file in files: writer.writerow(wav_file) - logging.info("Successfully generated csv file {}".format(csv_file_path)) + logger.info("Successfully generated csv file {}".format(csv_file_path)) def processor(directory, subset, force_process): @@ -203,16 +205,16 @@ def processor(directory, subset, force_process): if not force_process and os.path.exists(subset_csv): return subset_csv - logging.info("Downloading and process the voxceleb in %s", directory) - logging.info("Preparing subset %s", subset) + logger.info("Downloading and process the voxceleb in %s", directory) + logger.info("Preparing subset %s", subset) download_and_extract(directory, subset, urls[subset]) convert_audio_and_make_label(directory, subset, directory, subset + ".csv") - logging.info("Finished downloading and processing") + logger.info("Finished downloading and processing") return subset_csv if __name__ == "__main__": - logging.set_verbosity(logging.INFO) + logging.getLogger("TTS").setLevel(logging.INFO) if len(sys.argv) != 4: print("Usage: python prepare_data.py save_directory user password") sys.exit() diff --git a/TTS/server/server.py b/TTS/server/server.py index 01bd79a137..a8f3a08817 100644 --- a/TTS/server/server.py +++ b/TTS/server/server.py @@ -2,6 +2,7 @@ import argparse import io import json +import logging import os import sys from pathlib import Path @@ -18,6 +19,9 @@ from TTS.utils.manage import ModelManager from TTS.utils.synthesizer import Synthesizer +logger = logging.getLogger(__name__) +logging.getLogger("TTS").setLevel(logging.INFO) + def create_argparser(): def convert_boolean(x): @@ -200,9 +204,9 @@ def tts(): style_wav = request.headers.get("style-wav") or request.values.get("style_wav", "") style_wav = style_wav_uri_to_dict(style_wav) - print(f" > Model input: {text}") - print(f" > Speaker Idx: {speaker_idx}") - print(f" > Language Idx: {language_idx}") + logger.info("Model input: %s", text) + logger.info("Speaker idx: %s", speaker_idx) + logger.info("Language idx: %s", language_idx) wavs = synthesizer.tts(text, speaker_name=speaker_idx, language_name=language_idx, style_wav=style_wav) out = io.BytesIO() synthesizer.save_wav(wavs, out) @@ -246,7 +250,7 @@ def mary_tts_api_process(): text = data.get("INPUT_TEXT", [""])[0] else: text = request.args.get("INPUT_TEXT", "") - print(f" > Model input: {text}") + logger.info("Model input: %s", text) wavs = synthesizer.tts(text) out = io.BytesIO() synthesizer.save_wav(wavs, out) diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py index 4f354fa0be..f9f2cb2e37 100644 --- a/TTS/tts/datasets/__init__.py +++ b/TTS/tts/datasets/__init__.py @@ -1,3 +1,4 @@ +import logging import os import sys from collections import Counter @@ -9,6 +10,8 @@ from TTS.tts.datasets.dataset import * from TTS.tts.datasets.formatters import * +logger = logging.getLogger(__name__) + def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01): """Split a dataset into train and eval. Consider speaker distribution in multi-speaker training. @@ -122,7 +125,7 @@ def load_tts_samples( meta_data_train = add_extra_keys(meta_data_train, language, dataset_name) - print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}") + logger.info("Found %d files in %s", len(meta_data_train), Path(root_path).resolve()) # load evaluation split if set if eval_split: if meta_file_val: @@ -166,16 +169,15 @@ def _get_formatter_by_name(name): return getattr(thismodule, name.lower()) -def find_unique_chars(data_samples, verbose=True): +def find_unique_chars(data_samples): texts = "".join(item["text"] for item in data_samples) chars = set(texts) lower_chars = filter(lambda c: c.islower(), chars) chars_force_lower = [c.lower() for c in chars] chars_force_lower = set(chars_force_lower) - if verbose: - print(f" > Number of unique characters: {len(chars)}") - print(f" > Unique characters: {''.join(sorted(chars))}") - print(f" > Unique lower characters: {''.join(sorted(lower_chars))}") - print(f" > Unique all forced to lower characters: {''.join(sorted(chars_force_lower))}") + logger.info("Number of unique characters: %d", len(chars)) + logger.info("Unique characters: %s", "".join(sorted(chars))) + logger.info("Unique lower characters: %s", "".join(sorted(lower_chars))) + logger.info("Unique all forced to lower characters: %s", "".join(sorted(chars_force_lower))) return chars_force_lower diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index 257d1c3100..3886a8f8c9 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -1,5 +1,6 @@ import base64 import collections +import logging import os import random from typing import Dict, List, Union @@ -14,6 +15,8 @@ from TTS.utils.audio import AudioProcessor from TTS.utils.audio.numpy_transforms import compute_energy as calculate_energy +logger = logging.getLogger(__name__) + # to prevent too many open files error as suggested here # https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936 torch.multiprocessing.set_sharing_strategy("file_system") @@ -79,7 +82,6 @@ def __init__( language_id_mapping: Dict = None, use_noise_augment: bool = False, start_by_longest: bool = False, - verbose: bool = False, ): """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs. @@ -137,8 +139,6 @@ def __init__( use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False. start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False. - - verbose (bool): Print diagnostic information. Defaults to false. """ super().__init__() self.batch_group_size = batch_group_size @@ -162,7 +162,6 @@ def __init__( self.use_noise_augment = use_noise_augment self.start_by_longest = start_by_longest - self.verbose = verbose self.rescue_item_idx = 1 self.pitch_computed = False self.tokenizer = tokenizer @@ -180,8 +179,7 @@ def __init__( self.energy_dataset = EnergyDataset( self.samples, self.ap, cache_path=energy_cache_path, precompute_num_workers=precompute_num_workers ) - if self.verbose: - self.print_logs() + self.print_logs() @property def lengths(self): @@ -214,11 +212,10 @@ def __getitem__(self, idx): def print_logs(self, level: int = 0) -> None: indent = "\t" * level - print("\n") - print(f"{indent}> DataLoader initialization") - print(f"{indent}| > Tokenizer:") + logger.info("%sDataLoader initialization", indent) + logger.info("%s| Tokenizer:", indent) self.tokenizer.print_logs(level + 1) - print(f"{indent}| > Number of instances : {len(self.samples)}") + logger.info("%s| Number of instances : %d", indent, len(self.samples)) def load_wav(self, filename): waveform = self.ap.load_wav(filename) @@ -390,17 +387,15 @@ def preprocess_samples(self): text_lengths = [s["text_length"] for s in samples] self.samples = samples - if self.verbose: - print(" | > Preprocessing samples") - print(" | > Max text length: {}".format(np.max(text_lengths))) - print(" | > Min text length: {}".format(np.min(text_lengths))) - print(" | > Avg text length: {}".format(np.mean(text_lengths))) - print(" | ") - print(" | > Max audio length: {}".format(np.max(audio_lengths))) - print(" | > Min audio length: {}".format(np.min(audio_lengths))) - print(" | > Avg audio length: {}".format(np.mean(audio_lengths))) - print(f" | > Num. instances discarded samples: {len(ignore_idx)}") - print(" | > Batch group size: {}.".format(self.batch_group_size)) + logger.info("Preprocessing samples") + logger.info("Max text length: {}".format(np.max(text_lengths))) + logger.info("Min text length: {}".format(np.min(text_lengths))) + logger.info("Avg text length: {}".format(np.mean(text_lengths))) + logger.info("Max audio length: {}".format(np.max(audio_lengths))) + logger.info("Min audio length: {}".format(np.min(audio_lengths))) + logger.info("Avg audio length: {}".format(np.mean(audio_lengths))) + logger.info("Num. instances discarded samples: %d", len(ignore_idx)) + logger.info("Batch group size: {}.".format(self.batch_group_size)) @staticmethod def _sort_batch(batch, text_lengths): @@ -643,7 +638,7 @@ def precompute(self, num_workers=1): We use pytorch dataloader because we are lazy. """ - print("[*] Pre-computing phonemes...") + logger.info("Pre-computing phonemes...") with tqdm.tqdm(total=len(self)) as pbar: batch_size = num_workers if num_workers > 0 else 1 dataloder = torch.utils.data.DataLoader( @@ -665,11 +660,10 @@ def collate_fn(self, batch): def print_logs(self, level: int = 0) -> None: indent = "\t" * level - print("\n") - print(f"{indent}> PhonemeDataset ") - print(f"{indent}| > Tokenizer:") + logger.info("%sPhonemeDataset", indent) + logger.info("%s| Tokenizer:", indent) self.tokenizer.print_logs(level + 1) - print(f"{indent}| > Number of instances : {len(self.samples)}") + logger.info("%s| Number of instances : %d", indent, len(self.samples)) class F0Dataset: @@ -701,14 +695,12 @@ def __init__( samples: Union[List[List], List[Dict]], ap: "AudioProcessor", audio_config=None, # pylint: disable=unused-argument - verbose=False, cache_path: str = None, precompute_num_workers=0, normalize_f0=True, ): self.samples = samples self.ap = ap - self.verbose = verbose self.cache_path = cache_path self.normalize_f0 = normalize_f0 self.pad_id = 0.0 @@ -732,7 +724,7 @@ def __len__(self): return len(self.samples) def precompute(self, num_workers=0): - print("[*] Pre-computing F0s...") + logger.info("Pre-computing F0s...") with tqdm.tqdm(total=len(self)) as pbar: batch_size = num_workers if num_workers > 0 else 1 # we do not normalize at preproessing @@ -819,9 +811,8 @@ def collate_fn(self, batch): def print_logs(self, level: int = 0) -> None: indent = "\t" * level - print("\n") - print(f"{indent}> F0Dataset ") - print(f"{indent}| > Number of instances : {len(self.samples)}") + logger.info("%sF0Dataset", indent) + logger.info("%s| Number of instances : %d", indent, len(self.samples)) class EnergyDataset: @@ -852,14 +843,12 @@ def __init__( self, samples: Union[List[List], List[Dict]], ap: "AudioProcessor", - verbose=False, cache_path: str = None, precompute_num_workers=0, normalize_energy=True, ): self.samples = samples self.ap = ap - self.verbose = verbose self.cache_path = cache_path self.normalize_energy = normalize_energy self.pad_id = 0.0 @@ -883,7 +872,7 @@ def __len__(self): return len(self.samples) def precompute(self, num_workers=0): - print("[*] Pre-computing energys...") + logger.info("Pre-computing energys...") with tqdm.tqdm(total=len(self)) as pbar: batch_size = num_workers if num_workers > 0 else 1 # we do not normalize at preproessing @@ -971,6 +960,5 @@ def collate_fn(self, batch): def print_logs(self, level: int = 0) -> None: indent = "\t" * level - print("\n") - print(f"{indent}> energyDataset ") - print(f"{indent}| > Number of instances : {len(self.samples)}") + logger.info("%senergyDataset") + logger.info("%s| Number of instances : %d", indent, len(self.samples)) diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py index 09fbd094e8..ff1a76e2c9 100644 --- a/TTS/tts/datasets/formatters.py +++ b/TTS/tts/datasets/formatters.py @@ -1,4 +1,5 @@ import csv +import logging import os import re import xml.etree.ElementTree as ET @@ -8,6 +9,8 @@ from tqdm import tqdm +logger = logging.getLogger(__name__) + ######################## # DATASETS ######################## @@ -23,7 +26,7 @@ def cml_tts(root_path, meta_file, ignored_speakers=None): num_cols = len(lines[0].split("|")) # take the first row as reference for idx, line in enumerate(lines[1:]): if len(line.split("|")) != num_cols: - print(f" > Missing column in line {idx + 1} -> {line.strip()}") + logger.warning("Missing column in line %d -> %s", idx + 1, line.strip()) # load metadata with open(Path(root_path) / meta_file, newline="", encoding="utf-8") as f: reader = csv.DictReader(f, delimiter="|") @@ -50,7 +53,7 @@ def cml_tts(root_path, meta_file, ignored_speakers=None): } ) if not_found_counter > 0: - print(f" | > [!] {not_found_counter} files not found") + logger.warning("%d files not found", not_found_counter) return items @@ -63,7 +66,7 @@ def coqui(root_path, meta_file, ignored_speakers=None): num_cols = len(lines[0].split("|")) # take the first row as reference for idx, line in enumerate(lines[1:]): if len(line.split("|")) != num_cols: - print(f" > Missing column in line {idx + 1} -> {line.strip()}") + logger.warning("Missing column in line %d -> %s", idx + 1, line.strip()) # load metadata with open(Path(root_path) / meta_file, newline="", encoding="utf-8") as f: reader = csv.DictReader(f, delimiter="|") @@ -90,7 +93,7 @@ def coqui(root_path, meta_file, ignored_speakers=None): } ) if not_found_counter > 0: - print(f" | > [!] {not_found_counter} files not found") + logger.warning("%d files not found", not_found_counter) return items @@ -173,7 +176,7 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None): if isinstance(ignored_speakers, list): if speaker_name in ignored_speakers: continue - print(" | > {}".format(csv_file)) + logger.info(csv_file) with open(txt_file, "r", encoding="utf-8") as ttf: for line in ttf: cols = line.split("|") @@ -188,7 +191,7 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None): ) else: # M-AI-Labs have some missing samples, so just print the warning - print("> File %s does not exist!" % (wav_file)) + logger.warning("File %s does not exist!", wav_file) return items @@ -253,7 +256,7 @@ def sam_accenture(root_path, meta_file, **kwargs): # pylint: disable=unused-arg text = item.text wav_file = os.path.join(root_path, "vo_voice_quality_transformation", item.get("id") + ".wav") if not os.path.exists(wav_file): - print(f" [!] {wav_file} in metafile does not exist. Skipping...") + logger.warning("%s in metafile does not exist. Skipping...", wav_file) continue items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) return items @@ -374,7 +377,7 @@ def custom_turkish(root_path, meta_file, **kwargs): # pylint: disable=unused-ar continue text = cols[1].strip() items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) - print(f" [!] {len(skipped_files)} files skipped. They don't exist...") + logger.warning("%d files skipped. They don't exist...") return items @@ -442,7 +445,7 @@ def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic {"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id, "root_path": root_path} ) else: - print(f" [!] wav files don't exist - {wav_file}") + logger.warning("Wav file doesn't exist - %s", wav_file) return items diff --git a/TTS/tts/layers/bark/hubert/hubert_manager.py b/TTS/tts/layers/bark/hubert/hubert_manager.py index 4bc1992941..fd936a9157 100644 --- a/TTS/tts/layers/bark/hubert/hubert_manager.py +++ b/TTS/tts/layers/bark/hubert/hubert_manager.py @@ -1,11 +1,14 @@ # From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer +import logging import os.path import shutil import urllib.request import huggingface_hub +logger = logging.getLogger(__name__) + class HubertManager: @staticmethod @@ -13,9 +16,9 @@ def make_sure_hubert_installed( download_url: str = "https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt", model_path: str = "" ): if not os.path.isfile(model_path): - print("Downloading HuBERT base model") + logger.info("Downloading HuBERT base model") urllib.request.urlretrieve(download_url, model_path) - print("Downloaded HuBERT") + logger.info("Downloaded HuBERT") return model_path return None @@ -27,9 +30,9 @@ def make_sure_tokenizer_installed( ): model_dir = os.path.dirname(model_path) if not os.path.isfile(model_path): - print("Downloading HuBERT custom tokenizer") + logger.info("Downloading HuBERT custom tokenizer") huggingface_hub.hf_hub_download(repo, model, local_dir=model_dir, local_dir_use_symlinks=False) shutil.move(os.path.join(model_dir, model), model_path) - print("Downloaded tokenizer") + logger.info("Downloaded tokenizer") return model_path return None diff --git a/TTS/tts/layers/bark/hubert/tokenizer.py b/TTS/tts/layers/bark/hubert/tokenizer.py index 3070241f1c..cd9579799a 100644 --- a/TTS/tts/layers/bark/hubert/tokenizer.py +++ b/TTS/tts/layers/bark/hubert/tokenizer.py @@ -5,6 +5,7 @@ """ import json +import logging import os.path from zipfile import ZipFile @@ -12,6 +13,8 @@ import torch from torch import nn, optim +logger = logging.getLogger(__name__) + class HubertTokenizer(nn.Module): def __init__(self, hidden_size=1024, input_size=768, output_size=10000, version=0): @@ -85,7 +88,7 @@ def train_step(self, x_train, y_train, log_loss=False): # Print loss if log_loss: - print("Loss", loss.item()) + logger.info("Loss %.3f", loss.item()) # Backward pass loss.backward() @@ -157,10 +160,10 @@ def auto_train(data_path, save_path="model.pth", load_model: str = None, save_ep data_x, data_y = [], [] if load_model and os.path.isfile(load_model): - print("Loading model from", load_model) + logger.info("Loading model from %s", load_model) model_training = HubertTokenizer.load_from_checkpoint(load_model, "cuda") else: - print("Creating new model.") + logger.info("Creating new model.") model_training = HubertTokenizer(version=1).to("cuda") # Settings for the model to run without lstm save_path = os.path.join(data_path, save_path) base_save_path = ".".join(save_path.split(".")[:-1]) @@ -191,5 +194,5 @@ def auto_train(data_path, save_path="model.pth", load_model: str = None, save_ep save_p_2 = f"{base_save_path}_epoch_{epoch}.pth" model_training.save(save_p) model_training.save(save_p_2) - print(f"Epoch {epoch} completed") + logger.info("Epoch %d completed", epoch) epoch += 1 diff --git a/TTS/tts/layers/delightful_tts/acoustic_model.py b/TTS/tts/layers/delightful_tts/acoustic_model.py index 74ec204281..83989f9ba4 100644 --- a/TTS/tts/layers/delightful_tts/acoustic_model.py +++ b/TTS/tts/layers/delightful_tts/acoustic_model.py @@ -1,4 +1,5 @@ ### credit: https://github.com/dunky11/voicesmith +import logging from typing import Callable, Dict, Tuple import torch @@ -20,6 +21,8 @@ from TTS.tts.layers.generic.aligner import AlignmentNetwork from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask +logger = logging.getLogger(__name__) + class AcousticModel(torch.nn.Module): def __init__( @@ -217,7 +220,7 @@ def _set_speaker_input(self, aux_input: Dict): def _init_speaker_embedding(self): # pylint: disable=attribute-defined-outside-init if self.num_speakers > 0: - print(" > initialization of speaker-embedding layers.") + logger.info("Initialization of speaker-embedding layers.") self.embedded_speaker_dim = self.args.speaker_embedding_channels self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index de5f408c48..cd6cd0aeb2 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -1,3 +1,4 @@ +import logging import math import numpy as np @@ -10,6 +11,8 @@ from TTS.tts.utils.ssim import SSIMLoss as _SSIMLoss from TTS.utils.audio.torch_transforms import TorchSTFT +logger = logging.getLogger(__name__) + # pylint: disable=abstract-method # relates https://github.com/pytorch/pytorch/issues/42305 @@ -132,11 +135,11 @@ def forward(self, y_hat, y, length): ssim_loss = self.loss_func((y_norm * mask).unsqueeze(1), (y_hat_norm * mask).unsqueeze(1)) if ssim_loss.item() > 1.0: - print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 1.0") + logger.info("SSIM loss is out-of-range (%.2f), setting it to 1.0", ssim_loss.item()) ssim_loss = torch.tensor(1.0, device=ssim_loss.device) if ssim_loss.item() < 0.0: - print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 0.0") + logger.info("SSIM loss is out-of-range (%.2f), setting it to 0.0", ssim_loss.item()) ssim_loss = torch.tensor(0.0, device=ssim_loss.device) return ssim_loss diff --git a/TTS/tts/layers/overflow/common_layers.py b/TTS/tts/layers/overflow/common_layers.py index b036dd1bda..9f77af293c 100644 --- a/TTS/tts/layers/overflow/common_layers.py +++ b/TTS/tts/layers/overflow/common_layers.py @@ -1,3 +1,4 @@ +import logging from typing import List, Tuple import torch @@ -8,6 +9,8 @@ from TTS.tts.layers.tacotron.common_layers import Linear from TTS.tts.layers.tacotron.tacotron2 import ConvBNBlock +logger = logging.getLogger(__name__) + class Encoder(nn.Module): r"""Neural HMM Encoder @@ -213,8 +216,8 @@ def _floor_std(self, std): original_tensor = std.clone().detach() std = torch.clamp(std, min=self.std_floor) if torch.any(original_tensor != std): - print( - "[*] Standard deviation was floored! The model is preventing overfitting, nothing serious to worry about" + logger.info( + "Standard deviation was floored! The model is preventing overfitting, nothing serious to worry about" ) return std diff --git a/TTS/tts/layers/tacotron/tacotron.py b/TTS/tts/layers/tacotron/tacotron.py index 7a47c35ef6..32643dfcee 100644 --- a/TTS/tts/layers/tacotron/tacotron.py +++ b/TTS/tts/layers/tacotron/tacotron.py @@ -1,12 +1,16 @@ # coding: utf-8 # adapted from https://github.com/r9y9/tacotron_pytorch +import logging + import torch from torch import nn from .attentions import init_attn from .common_layers import Prenet +logger = logging.getLogger(__name__) + class BatchNormConv1d(nn.Module): r"""A wrapper for Conv1d with BatchNorm. It sets the activation @@ -480,7 +484,7 @@ def inference(self, inputs): if t > inputs.shape[1] / 4 and (stop_token > 0.6 or attention[:, -1].item() > 0.6): break if t > self.max_decoder_steps: - print(" | > Decoder stopped with 'max_decoder_steps") + logger.info("Decoder stopped with `max_decoder_steps` %d", self.max_decoder_steps) break return self._parse_outputs(outputs, attentions, stop_tokens) diff --git a/TTS/tts/layers/tacotron/tacotron2.py b/TTS/tts/layers/tacotron/tacotron2.py index c79b709972..727bf9ecfd 100644 --- a/TTS/tts/layers/tacotron/tacotron2.py +++ b/TTS/tts/layers/tacotron/tacotron2.py @@ -1,3 +1,5 @@ +import logging + import torch from torch import nn from torch.nn import functional as F @@ -5,6 +7,8 @@ from .attentions import init_attn from .common_layers import Linear, Prenet +logger = logging.getLogger(__name__) + # pylint: disable=no-value-for-parameter # pylint: disable=unexpected-keyword-arg @@ -356,7 +360,7 @@ def inference(self, inputs): if stop_token > self.stop_threshold and t > inputs.shape[0] // 2: break if len(outputs) == self.max_decoder_steps: - print(f" > Decoder stopped with `max_decoder_steps` {self.max_decoder_steps}") + logger.info("Decoder stopped with `max_decoder_steps` %d", self.max_decoder_steps) break memory = self._update_memory(decoder_output) @@ -389,7 +393,7 @@ def inference_truncated(self, inputs): if stop_token > 0.7: break if len(outputs) == self.max_decoder_steps: - print(" | > Decoder stopped with 'max_decoder_steps") + logger.info("Decoder stopped with `max_decoder_steps` %d", self.max_decoder_steps) break self.memory_truncated = decoder_output diff --git a/TTS/tts/layers/tortoise/audio_utils.py b/TTS/tts/layers/tortoise/audio_utils.py index 70711ed7a4..0b8701227b 100644 --- a/TTS/tts/layers/tortoise/audio_utils.py +++ b/TTS/tts/layers/tortoise/audio_utils.py @@ -1,3 +1,4 @@ +import logging import os from glob import glob from typing import Dict, List @@ -10,6 +11,8 @@ from TTS.utils.audio.torch_transforms import TorchSTFT +logger = logging.getLogger(__name__) + def load_wav_to_torch(full_path): sampling_rate, data = read(full_path) @@ -28,7 +31,7 @@ def check_audio(audio, audiopath: str): # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk. # '2' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds. if torch.any(audio > 2) or not torch.any(audio < 0): - print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}") + logger.error("Error with %s. Max=%.2f min=%.2f", audiopath, audio.max(), audio.min()) audio.clip_(-1, 1) @@ -136,7 +139,7 @@ def load_voices(voices: List[str], extra_voice_dirs: List[str] = []): for voice in voices: if voice == "random": if len(voices) > 1: - print("Cannot combine a random voice with a non-random voice. Just using a random voice.") + logger.warning("Cannot combine a random voice with a non-random voice. Just using a random voice.") return None, None clip, latent = load_voice(voice, extra_voice_dirs) if latent is None: diff --git a/TTS/tts/layers/tortoise/dpm_solver.py b/TTS/tts/layers/tortoise/dpm_solver.py index c70888df42..6a1d8ff784 100644 --- a/TTS/tts/layers/tortoise/dpm_solver.py +++ b/TTS/tts/layers/tortoise/dpm_solver.py @@ -1,7 +1,10 @@ +import logging import math import torch +logger = logging.getLogger(__name__) + class NoiseScheduleVP: def __init__( @@ -1171,7 +1174,7 @@ def norm_fn(v): lambda_0 - lambda_s, ) nfe += order - print("adaptive solver nfe", nfe) + logger.debug("adaptive solver nfe %d", nfe) return x def add_noise(self, x, t, noise=None): diff --git a/TTS/tts/layers/tortoise/utils.py b/TTS/tts/layers/tortoise/utils.py index 810a9e7f7a..898121f793 100644 --- a/TTS/tts/layers/tortoise/utils.py +++ b/TTS/tts/layers/tortoise/utils.py @@ -1,8 +1,11 @@ +import logging import os from urllib import request from tqdm import tqdm +logger = logging.getLogger(__name__) + DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser("~"), ".cache", "tortoise", "models") MODELS_DIR = os.environ.get("TORTOISE_MODELS_DIR", DEFAULT_MODELS_DIR) MODELS_DIR = "/data/speech_synth/models/" @@ -28,10 +31,10 @@ def download_models(specific_models=None): model_path = os.path.join(MODELS_DIR, model_name) if os.path.exists(model_path): continue - print(f"Downloading {model_name} from {url}...") + logger.info("Downloading %s from %s...", model_name, url) with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: request.urlretrieve(url, model_path, lambda nb, bs, fs, t=t: t.update(nb * bs - t.n)) - print("Done.") + logger.info("Done.") def get_model_path(model_name, models_dir=MODELS_DIR): diff --git a/TTS/tts/layers/xtts/dvae.py b/TTS/tts/layers/xtts/dvae.py index 8598f0b47a..4a37307e74 100644 --- a/TTS/tts/layers/xtts/dvae.py +++ b/TTS/tts/layers/xtts/dvae.py @@ -1,4 +1,5 @@ import functools +import logging from math import sqrt import torch @@ -8,6 +9,8 @@ import torchaudio from einops import rearrange +logger = logging.getLogger(__name__) + def default(val, d): return val if val is not None else d @@ -79,7 +82,7 @@ def forward(self, input, return_soft_codes=False): self.embed_avg = (ea * ~mask + rand_embed).permute(1, 0) self.cluster_size = self.cluster_size * ~mask.squeeze() if torch.any(mask): - print(f"Reset {torch.sum(mask)} embedding codes.") + logger.info("Reset %d embedding codes.", torch.sum(mask)) self.codes = None self.codes_full = False diff --git a/TTS/tts/layers/xtts/hifigan_decoder.py b/TTS/tts/layers/xtts/hifigan_decoder.py index 9add7826e6..42f64e6807 100644 --- a/TTS/tts/layers/xtts/hifigan_decoder.py +++ b/TTS/tts/layers/xtts/hifigan_decoder.py @@ -1,3 +1,5 @@ +import logging + import torch import torchaudio from torch import nn @@ -8,6 +10,8 @@ from TTS.utils.io import load_fsspec +logger = logging.getLogger(__name__) + LRELU_SLOPE = 0.1 @@ -316,7 +320,7 @@ def inference(self, c): return self.forward(c) def remove_weight_norm(self): - print("Removing weight norm...") + logger.info("Removing weight norm...") for l in self.ups: remove_parametrizations(l, "weight") for l in self.resblocks: @@ -390,7 +394,7 @@ def set_init_dict(model_dict, checkpoint_state, c): # Partial initialization: if there is a mismatch with new and old layer, it is skipped. for k, v in checkpoint_state.items(): if k not in model_dict: - print(" | > Layer missing in the model definition: {}".format(k)) + logger.warning("Layer missing in the model definition: %s", k) # 1. filter out unnecessary keys pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict} # 2. filter out different size layers @@ -401,7 +405,7 @@ def set_init_dict(model_dict, checkpoint_state, c): pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k} # 4. overwrite entries in the existing state dict model_dict.update(pretrained_dict) - print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict))) + logger.info("%d / %d layers are restored.", len(pretrained_dict), len(model_dict)) return model_dict @@ -579,13 +583,13 @@ def load_checkpoint( state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) try: self.load_state_dict(state["model"]) - print(" > Model fully restored. ") + logger.info("Model fully restored.") except (KeyError, RuntimeError) as error: # If eval raise the error if eval: raise error - print(" > Partial model initialization.") + logger.info("Partial model initialization.") model_dict = self.state_dict() model_dict = set_init_dict(model_dict, state["model"]) self.load_state_dict(model_dict) @@ -596,7 +600,7 @@ def load_checkpoint( try: criterion.load_state_dict(state["criterion"]) except (KeyError, RuntimeError) as error: - print(" > Criterion load ignored because of:", error) + logger.exception("Criterion load ignored because of: %s", error) if use_cuda: self.cuda() diff --git a/TTS/tts/layers/xtts/tokenizer.py b/TTS/tts/layers/xtts/tokenizer.py index 1a3cc47aaf..d4c3f0bbb8 100644 --- a/TTS/tts/layers/xtts/tokenizer.py +++ b/TTS/tts/layers/xtts/tokenizer.py @@ -1,3 +1,4 @@ +import logging import os import re import textwrap @@ -17,6 +18,8 @@ from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words +logger = logging.getLogger(__name__) + def get_spacy_lang(lang): if lang == "zh": @@ -623,8 +626,10 @@ def check_input_length(self, txt, lang): lang = lang.split("-")[0] # remove the region limit = self.char_limits.get(lang, 250) if len(txt) > limit: - print( - f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio." + logger.warning( + "The text length exceeds the character limit of %d for language '%s', this might cause truncated audio.", + limit, + lang, ) def preprocess_text(self, txt, lang): diff --git a/TTS/tts/layers/xtts/trainer/dataset.py b/TTS/tts/layers/xtts/trainer/dataset.py index 0a19997a47..e598232665 100644 --- a/TTS/tts/layers/xtts/trainer/dataset.py +++ b/TTS/tts/layers/xtts/trainer/dataset.py @@ -1,3 +1,4 @@ +import logging import random import sys @@ -7,6 +8,8 @@ from TTS.tts.models.xtts import load_audio +logger = logging.getLogger(__name__) + torch.set_num_threads(1) @@ -70,13 +73,13 @@ def __init__(self, config, samples, tokenizer, sample_rate, is_eval=False): random.shuffle(self.samples) # order by language self.samples = key_samples_by_col(self.samples, "language") - print(" > Sampling by language:", self.samples.keys()) + logger.info("Sampling by language: %s", self.samples.keys()) else: # for evaluation load and check samples that are corrupted to ensures the reproducibility self.check_eval_samples() def check_eval_samples(self): - print(" > Filtering invalid eval samples!!") + logger.info("Filtering invalid eval samples!!") new_samples = [] for sample in self.samples: try: @@ -92,7 +95,7 @@ def check_eval_samples(self): continue new_samples.append(sample) self.samples = new_samples - print(" > Total eval samples after filtering:", len(self.samples)) + logger.info("Total eval samples after filtering: %d", len(self.samples)) def get_text(self, text, lang): tokens = self.tokenizer.encode(text, lang) @@ -150,7 +153,7 @@ def __getitem__(self, index): # ignore samples that we already know that is not valid ones if sample_id in self.failed_samples: if self.debug_failures: - print(f"Ignoring sample {sample['audio_file']} because it was already ignored before !!") + logger.info("Ignoring sample %s because it was already ignored before !!", sample["audio_file"]) # call get item again to get other sample return self[1] @@ -159,7 +162,7 @@ def __getitem__(self, index): tseq, audiopath, wav, cond, cond_len, cond_idxs = self.load_item(sample) except: if self.debug_failures: - print(f"error loading {sample['audio_file']} {sys.exc_info()}") + logger.warning("Error loading %s %s", sample["audio_file"], sys.exc_info()) self.failed_samples.add(sample_id) return self[1] @@ -172,8 +175,11 @@ def __getitem__(self, index): # Basically, this audio file is nonexistent or too long to be supported by the dataset. # It's hard to handle this situation properly. Best bet is to return the a random valid token and skew the dataset somewhat as a result. if self.debug_failures and wav is not None and tseq is not None: - print( - f"error loading {sample['audio_file']}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}" + logger.warning( + "Error loading %s: ranges are out of bounds: %d, %d", + sample["audio_file"], + wav.shape[-1], + tseq.shape[0], ) self.failed_samples.add(sample_id) return self[1] diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/TTS/tts/layers/xtts/trainer/gpt_trainer.py index daf9fc7e4f..0f161324f8 100644 --- a/TTS/tts/layers/xtts/trainer/gpt_trainer.py +++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py @@ -1,3 +1,4 @@ +import logging from dataclasses import dataclass, field from typing import Dict, List, Tuple, Union @@ -19,6 +20,8 @@ from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig from TTS.utils.io import load_fsspec +logger = logging.getLogger(__name__) + @dataclass class GPTTrainerConfig(XttsConfig): @@ -57,7 +60,7 @@ def callback_clearml_load_save(operation_type, model_info): # return None means skip the file upload/log, returning model_info will continue with the log/upload # you can also change the upload destination file name model_info.upload_filename or check the local file size with Path(model_info.local_model_path).stat().st_size assert operation_type in ("load", "save") - # print(operation_type, model_info.__dict__) + logger.debug("%s %s", operation_type, model_info.__dict__) if "similarities.pth" in model_info.__dict__["local_model_path"]: return None @@ -91,7 +94,7 @@ def __init__(self, config: Coqpit): gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu")) # deal with coqui Trainer exported model if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys(): - print("Coqui Trainer checkpoint detected! Converting it!") + logger.info("Coqui Trainer checkpoint detected! Converting it!") gpt_checkpoint = gpt_checkpoint["model"] states_keys = list(gpt_checkpoint.keys()) for key in states_keys: @@ -110,7 +113,7 @@ def __init__(self, config: Coqpit): num_new_tokens = ( self.xtts.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0] ) - print(f" > Loading checkpoint with {num_new_tokens} additional tokens.") + logger.info("Loading checkpoint with %d additional tokens.", num_new_tokens) # add new tokens to a linear layer (text_head) emb_g = gpt_checkpoint["text_embedding.weight"] @@ -137,7 +140,7 @@ def __init__(self, config: Coqpit): gpt_checkpoint["text_head.bias"] = text_head_bias self.xtts.gpt.load_state_dict(gpt_checkpoint, strict=True) - print(">> GPT weights restored from:", self.args.gpt_checkpoint) + logger.info("GPT weights restored from: %s", self.args.gpt_checkpoint) # Mel spectrogram extractor for conditioning if self.args.gpt_use_perceiver_resampler: @@ -183,7 +186,7 @@ def __init__(self, config: Coqpit): if self.args.dvae_checkpoint: dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu")) self.dvae.load_state_dict(dvae_checkpoint, strict=False) - print(">> DVAE weights restored from:", self.args.dvae_checkpoint) + logger.info("DVAE weights restored from: %s", self.args.dvae_checkpoint) else: raise RuntimeError( "You need to specify config.model_args.dvae_checkpoint path to be able to train the GPT decoder!!" @@ -229,7 +232,7 @@ def test_run(self, assets) -> Tuple[Dict, Dict]: # pylint: disable=W0613 # init gpt for inference mode self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False) self.xtts.gpt.eval() - print(" | > Synthesizing test sentences.") + logger.info("Synthesizing test sentences.") for idx, s_info in enumerate(self.config.test_sentences): wav = self.xtts.synthesize( s_info["text"], diff --git a/TTS/tts/layers/xtts/zh_num2words.py b/TTS/tts/layers/xtts/zh_num2words.py index 7d8f658160..69b8dae952 100644 --- a/TTS/tts/layers/xtts/zh_num2words.py +++ b/TTS/tts/layers/xtts/zh_num2words.py @@ -4,10 +4,13 @@ import argparse import csv +import logging import re import string import sys +logger = logging.getLogger(__name__) + # fmt: off # ================================================================================ # # basic constant @@ -923,12 +926,13 @@ def percentage2chntext(self): def normalize_nsw(raw_text): text = "^" + raw_text + "$" + logger.debug(text) # 规范化日期 pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)") matchers = pattern.findall(text) if matchers: - # print('date') + logger.debug("date") for matcher in matchers: text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1) @@ -936,7 +940,7 @@ def normalize_nsw(raw_text): pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)") matchers = pattern.findall(text) if matchers: - # print('money') + logger.debug("money") for matcher in matchers: text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1) @@ -949,14 +953,14 @@ def normalize_nsw(raw_text): pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D") matchers = pattern.findall(text) if matchers: - # print('telephone') + logger.debug("telephone") for matcher in matchers: text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1) # 固话 pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D") matchers = pattern.findall(text) if matchers: - # print('fixed telephone') + logger.debug("fixed telephone") for matcher in matchers: text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1) @@ -964,7 +968,7 @@ def normalize_nsw(raw_text): pattern = re.compile(r"(\d+/\d+)") matchers = pattern.findall(text) if matchers: - # print('fraction') + logger.debug("fraction") for matcher in matchers: text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1) @@ -973,7 +977,7 @@ def normalize_nsw(raw_text): pattern = re.compile(r"(\d+(\.\d+)?%)") matchers = pattern.findall(text) if matchers: - # print('percentage') + logger.debug("percentage") for matcher in matchers: text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1) @@ -981,7 +985,7 @@ def normalize_nsw(raw_text): pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS) matchers = pattern.findall(text) if matchers: - # print('cardinal+quantifier') + logger.debug("cardinal+quantifier") for matcher in matchers: text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1) @@ -989,7 +993,7 @@ def normalize_nsw(raw_text): pattern = re.compile(r"(\d{4,32})") matchers = pattern.findall(text) if matchers: - # print('digit') + logger.debug("digit") for matcher in matchers: text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1) @@ -997,7 +1001,7 @@ def normalize_nsw(raw_text): pattern = re.compile(r"(\d+(\.\d+)?)") matchers = pattern.findall(text) if matchers: - # print('cardinal') + logger.debug("cardinal") for matcher in matchers: text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1) @@ -1005,7 +1009,7 @@ def normalize_nsw(raw_text): pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))") matchers = pattern.findall(text) if matchers: - # print('particular') + logger.debug("particular") for matcher in matchers: text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1) @@ -1103,7 +1107,7 @@ def __call__(self, text): if self.check_chars: for c in text: if not IN_VALID_CHARS.get(c): - print(f"WARNING: illegal char {c} in: {text}", file=sys.stderr) + logger.warning("Illegal char %s in: %s", c, text) return "" if self.remove_space: diff --git a/TTS/tts/models/__init__.py b/TTS/tts/models/__init__.py index 2bd2e5f087..ebfa171c80 100644 --- a/TTS/tts/models/__init__.py +++ b/TTS/tts/models/__init__.py @@ -1,10 +1,13 @@ +import logging from typing import Dict, List, Union from TTS.utils.generic_utils import find_module +logger = logging.getLogger(__name__) + def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseTTS": - print(" > Using model: {}".format(config.model)) + logger.info("Using model: %s", config.model) # fetch the right model implementation. if "base_model" in config and config["base_model"] is not None: MyModel = find_module("TTS.tts.models", config.base_model.lower()) diff --git a/TTS/tts/models/base_tacotron.py b/TTS/tts/models/base_tacotron.py index f38dace235..33e1c11ab7 100644 --- a/TTS/tts/models/base_tacotron.py +++ b/TTS/tts/models/base_tacotron.py @@ -1,4 +1,5 @@ import copy +import logging from abc import abstractmethod from typing import Dict, Tuple @@ -17,6 +18,8 @@ from TTS.utils.io import load_fsspec from TTS.utils.training import gradual_training_scheduler +logger = logging.getLogger(__name__) + class BaseTacotron(BaseTTS): """Base class shared by Tacotron and Tacotron2""" @@ -116,7 +119,7 @@ def load_checkpoint( self.decoder.set_r(config.r) if eval: self.eval() - print(f" > Model's reduction rate `r` is set to: {self.decoder.r}") + logger.info("Model's reduction rate `r` is set to: %d", self.decoder.r) assert not self.training def get_criterion(self) -> nn.Module: @@ -148,7 +151,7 @@ def test_run(self, assets: Dict) -> Tuple[Dict, Dict]: Returns: Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. """ - print(" | > Synthesizing test sentences.") + logger.info("Synthesizing test sentences.") test_audios = {} test_figures = {} test_sentences = self.config.test_sentences @@ -302,4 +305,4 @@ def on_epoch_start(self, trainer): self.decoder.set_r(r) if trainer.config.bidirectional_decoder: trainer.model.decoder_backward.set_r(r) - print(f"\n > Number of output frames: {self.decoder.r}") + logger.info("Number of output frames: %d", self.decoder.r) diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 0aa5edc647..7fbc2a3a78 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -1,3 +1,4 @@ +import logging import os import random from typing import Dict, List, Tuple, Union @@ -18,6 +19,8 @@ from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +logger = logging.getLogger(__name__) + # pylint: skip-file @@ -105,7 +108,7 @@ def init_multispeaker(self, config: Coqpit, data: List = None): ) # init speaker embedding layer if config.use_speaker_embedding and not config.use_d_vector_file: - print(" > Init speaker_embedding layer.") + logger.info("Init speaker_embedding layer.") self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) self.speaker_embedding.weight.data.normal_(0, 0.3) @@ -245,12 +248,12 @@ def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1): if getattr(config, "use_language_weighted_sampler", False): alpha = getattr(config, "language_weighted_sampler_alpha", 1.0) - print(" > Using Language weighted sampler with alpha:", alpha) + logger.info("Using Language weighted sampler with alpha: %.2f", alpha) weights = get_language_balancer_weights(data_items) * alpha if getattr(config, "use_speaker_weighted_sampler", False): alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0) - print(" > Using Speaker weighted sampler with alpha:", alpha) + logger.info("Using Speaker weighted sampler with alpha: %.2f", alpha) if weights is not None: weights += get_speaker_balancer_weights(data_items) * alpha else: @@ -258,7 +261,7 @@ def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1): if getattr(config, "use_length_weighted_sampler", False): alpha = getattr(config, "length_weighted_sampler_alpha", 1.0) - print(" > Using Length weighted sampler with alpha:", alpha) + logger.info("Using Length weighted sampler with alpha: %.2f", alpha) if weights is not None: weights += get_length_balancer_weights(data_items) * alpha else: @@ -330,7 +333,6 @@ def get_data_loader( phoneme_cache_path=config.phoneme_cache_path, precompute_num_workers=config.precompute_num_workers, use_noise_augment=False if is_eval else config.use_noise_augment, - verbose=verbose, speaker_id_mapping=speaker_id_mapping, d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, tokenizer=self.tokenizer, @@ -390,7 +392,7 @@ def test_run(self, assets: Dict) -> Tuple[Dict, Dict]: Returns: Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. """ - print(" | > Synthesizing test sentences.") + logger.info("Synthesizing test sentences.") test_audios = {} test_figures = {} test_sentences = self.config.test_sentences @@ -429,8 +431,8 @@ def on_init_start(self, trainer): if hasattr(trainer.config, "model_args"): trainer.config.model_args.speakers_file = output_path trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) - print(f" > `speakers.pth` is saved to {output_path}.") - print(" > `speakers_file` is updated in the config.json.") + logger.info("`speakers.pth` is saved to: %s", output_path) + logger.info("`speakers_file` is updated in the config.json.") if self.language_manager is not None: output_path = os.path.join(trainer.output_path, "language_ids.json") @@ -439,8 +441,8 @@ def on_init_start(self, trainer): if hasattr(trainer.config, "model_args"): trainer.config.model_args.language_ids_file = output_path trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) - print(f" > `language_ids.json` is saved to {output_path}.") - print(" > `language_ids_file` is updated in the config.json.") + logger.info("`language_ids.json` is saved to: %s", output_path) + logger.info("`language_ids_file` is updated in the config.json.") class BaseTTSE2E(BaseTTS): diff --git a/TTS/tts/models/delightful_tts.py b/TTS/tts/models/delightful_tts.py index a4aa563f48..ed318923e9 100644 --- a/TTS/tts/models/delightful_tts.py +++ b/TTS/tts/models/delightful_tts.py @@ -1,3 +1,4 @@ +import logging import os from dataclasses import dataclass, field from itertools import chain @@ -36,6 +37,8 @@ from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.utils.generic_utils import plot_results +logger = logging.getLogger(__name__) + def id_to_torch(aux_id, cuda=False): if aux_id is not None: @@ -162,9 +165,9 @@ def _wav_to_spec(y, n_fft, hop_length, win_length, center=False): y = y.squeeze(1) if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) + logger.info("min value is %.3f", torch.min(y)) if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) + logger.info("max value is %.3f", torch.max(y)) global hann_window # pylint: disable=global-statement dtype_device = str(y.dtype) + "_" + str(y.device) @@ -253,9 +256,9 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm y = y.squeeze(1) if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) + logger.info("min value is %.3f", torch.min(y)) if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) + logger.info("max value is %.3f", torch.max(y)) global mel_basis, hann_window # pylint: disable=global-statement mel_basis_key = name_mel_basis(y, n_fft, fmax) @@ -328,7 +331,6 @@ def __init__( self, ap, samples: Union[List[List], List[Dict]], - verbose=False, cache_path: str = None, precompute_num_workers=0, normalize_f0=True, @@ -336,7 +338,6 @@ def __init__( super().__init__( samples=samples, ap=ap, - verbose=verbose, cache_path=cache_path, precompute_num_workers=precompute_num_workers, normalize_f0=normalize_f0, @@ -408,7 +409,7 @@ def __getitem__(self, idx): try: token_ids = self.get_token_ids(idx, item["text"]) except: - print(idx, item) + logger.exception("%s %s", idx, item) # pylint: disable=raise-missing-from raise OSError f0 = None @@ -773,7 +774,7 @@ def init_multispeaker(self, config: Coqpit): def _init_speaker_embedding(self): # pylint: disable=attribute-defined-outside-init if self.num_speakers > 0: - print(" > initialization of speaker-embedding layers.") + logger.info("Initialization of speaker-embedding layers.") self.embedded_speaker_dim = self.args.speaker_embedding_channels self.args.embedded_speaker_dim = self.args.speaker_embedding_channels @@ -1291,7 +1292,7 @@ def test_run(self, assets) -> Tuple[Dict, Dict]: Returns: Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. """ - print(" | > Synthesizing test sentences.") + logger.info("Synthesizing test sentences.") test_audios = {} test_figures = {} test_sentences = self.config.test_sentences @@ -1405,14 +1406,14 @@ def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1): data_items = dataset.samples if getattr(config, "use_weighted_sampler", False): for attr_name, alpha in config.weighted_sampler_attrs.items(): - print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'") + logger.info("Using weighted sampler for attribute '%s' with alpha %.2f", attr_name, alpha) multi_dict = config.weighted_sampler_multipliers.get(attr_name, None) - print(multi_dict) + logger.info(multi_dict) weights, attr_names, attr_weights = get_attribute_balancer_weights( attr_name=attr_name, items=data_items, multi_dict=multi_dict ) weights = weights * alpha - print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}") + logger.info("Attribute weights for '%s' \n | > %s", attr_names, attr_weights) if weights is not None: sampler = WeightedRandomSampler(weights, len(weights)) @@ -1452,7 +1453,6 @@ def get_data_loader( compute_f0=config.compute_f0, f0_cache_path=config.f0_cache_path, attn_prior_cache_path=config.attn_prior_cache_path if config.use_attn_priors else None, - verbose=verbose, tokenizer=self.tokenizer, start_by_longest=config.start_by_longest, ) @@ -1529,7 +1529,7 @@ def on_epoch_end(self, trainer): # pylint: disable=unused-argument @staticmethod def init_from_config( - config: "DelightfulTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=False + config: "DelightfulTTSConfig", samples: Union[List[List], List[Dict]] = None ): # pylint: disable=unused-argument """Initiate model from config diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py index 1d3a13d433..b108a554d5 100644 --- a/TTS/tts/models/forward_tts.py +++ b/TTS/tts/models/forward_tts.py @@ -1,3 +1,4 @@ +import logging from dataclasses import dataclass, field from typing import Dict, List, Tuple, Union @@ -18,6 +19,8 @@ from TTS.tts.utils.visual import plot_alignment, plot_avg_energy, plot_avg_pitch, plot_spectrogram from TTS.utils.io import load_fsspec +logger = logging.getLogger(__name__) + @dataclass class ForwardTTSArgs(Coqpit): @@ -303,7 +306,7 @@ def init_multispeaker(self, config: Coqpit): self.proj_g = nn.Linear(in_features=self.args.d_vector_dim, out_features=self.args.hidden_channels) # init speaker embedding layer if config.use_speaker_embedding and not config.use_d_vector_file: - print(" > Init speaker_embedding layer.") + logger.info("Init speaker_embedding layer.") self.emb_g = nn.Embedding(self.num_speakers, self.args.hidden_channels) nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index bfd1a2b618..a4ae012166 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -1,3 +1,4 @@ +import logging import math from typing import Dict, List, Tuple, Union @@ -18,6 +19,8 @@ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.io import load_fsspec +logger = logging.getLogger(__name__) + class GlowTTS(BaseTTS): """GlowTTS model. @@ -53,7 +56,7 @@ class GlowTTS(BaseTTS): >>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig >>> from TTS.tts.models.glow_tts import GlowTTS >>> config = GlowTTSConfig() - >>> model = GlowTTS.init_from_config(config, verbose=False) + >>> model = GlowTTS.init_from_config(config) """ def __init__( @@ -127,7 +130,7 @@ def init_multispeaker(self, config: Coqpit): ), " [!] d-vector dimension mismatch b/w config and speaker manager." # init speaker embedding layer if config.use_speaker_embedding and not config.use_d_vector_file: - print(" > Init speaker_embedding layer.") + logger.info("Init speaker_embedding layer.") self.embedded_speaker_dim = self.hidden_channels_enc self.emb_g = nn.Embedding(self.num_speakers, self.hidden_channels_enc) nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) @@ -479,13 +482,13 @@ def test_run(self, assets: Dict) -> Tuple[Dict, Dict]: Returns: Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. """ - print(" | > Synthesizing test sentences.") + logger.info("Synthesizing test sentences.") test_audios = {} test_figures = {} test_sentences = self.config.test_sentences aux_inputs = self._get_test_aux_input() if len(test_sentences) == 0: - print(" | [!] No test sentences provided.") + logger.warning("No test sentences provided.") else: for idx, sen in enumerate(test_sentences): outputs = synthesis( @@ -540,18 +543,17 @@ def on_train_step_start(self, trainer): self.run_data_dep_init = trainer.total_steps_done < self.data_dep_init_steps @staticmethod - def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): + def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None): """Initiate model from config Args: config (VitsConfig): Model config. samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. Defaults to None. - verbose (bool): If True, print init messages. Defaults to True. """ from TTS.utils.audio import AudioProcessor - ap = AudioProcessor.init_from_config(config, verbose) + ap = AudioProcessor.init_from_config(config) tokenizer, new_config = TTSTokenizer.init_from_config(config) speaker_manager = SpeakerManager.init_from_config(config, samples) return GlowTTS(new_config, ap, tokenizer, speaker_manager) diff --git a/TTS/tts/models/neuralhmm_tts.py b/TTS/tts/models/neuralhmm_tts.py index e241410872..d5bd9d1311 100644 --- a/TTS/tts/models/neuralhmm_tts.py +++ b/TTS/tts/models/neuralhmm_tts.py @@ -1,3 +1,4 @@ +import logging import os from typing import Dict, List, Union @@ -19,6 +20,8 @@ from TTS.utils.generic_utils import format_aux_input from TTS.utils.io import load_fsspec +logger = logging.getLogger(__name__) + class NeuralhmmTTS(BaseTTS): """Neural HMM TTS model. @@ -235,18 +238,17 @@ def get_criterion(): return NLLLoss() @staticmethod - def init_from_config(config: "NeuralhmmTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): + def init_from_config(config: "NeuralhmmTTSConfig", samples: Union[List[List], List[Dict]] = None): """Initiate model from config Args: config (VitsConfig): Model config. samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. Defaults to None. - verbose (bool): If True, print init messages. Defaults to True. """ from TTS.utils.audio import AudioProcessor - ap = AudioProcessor.init_from_config(config, verbose) + ap = AudioProcessor.init_from_config(config) tokenizer, new_config = TTSTokenizer.init_from_config(config) speaker_manager = SpeakerManager.init_from_config(config, samples) return NeuralhmmTTS(new_config, ap, tokenizer, speaker_manager) @@ -266,14 +268,17 @@ def on_init_start(self, trainer): dataloader = trainer.get_train_dataloader( training_assets=None, samples=trainer.train_samples, verbose=False ) - print( - f" | > Data parameters not found for: {trainer.config.mel_statistics_parameter_path}. Computing mel normalization parameters..." + logger.info( + "Data parameters not found for: %s. Computing mel normalization parameters...", + trainer.config.mel_statistics_parameter_path, ) data_mean, data_std, init_transition_prob = OverflowUtils.get_data_parameters_for_flat_start( dataloader, trainer.config.out_channels, trainer.config.state_per_phone ) - print( - f" | > Saving data parameters to: {trainer.config.mel_statistics_parameter_path}: value: {data_mean, data_std, init_transition_prob}" + logger.info( + "Saving data parameters to: %s: value: %s", + trainer.config.mel_statistics_parameter_path, + (data_mean, data_std, init_transition_prob), ) statistics = { "mean": data_mean.item(), @@ -283,8 +288,9 @@ def on_init_start(self, trainer): torch.save(statistics, trainer.config.mel_statistics_parameter_path) else: - print( - f" | > Data parameters found for: {trainer.config.mel_statistics_parameter_path}. Loading mel normalization parameters..." + logger.info( + "Data parameters found for: %s. Loading mel normalization parameters...", + trainer.config.mel_statistics_parameter_path, ) statistics = torch.load(trainer.config.mel_statistics_parameter_path) data_mean, data_std, init_transition_prob = ( @@ -292,7 +298,7 @@ def on_init_start(self, trainer): statistics["std"], statistics["init_transition_prob"], ) - print(f" | > Data parameters loaded with value: {data_mean, data_std, init_transition_prob}") + logger.info("Data parameters loaded with value: %s", (data_mean, data_std, init_transition_prob)) trainer.config.flat_start_params["transition_p"] = ( init_transition_prob.item() if torch.is_tensor(init_transition_prob) else init_transition_prob @@ -318,7 +324,7 @@ def _create_logs(self, batch, outputs, ap): # pylint: disable=no-self-use, unus } # sample one item from the batch -1 will give the smalles item - print(" | > Synthesising audio from the model...") + logger.info("Synthesising audio from the model...") inference_output = self.inference( batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)} ) diff --git a/TTS/tts/models/overflow.py b/TTS/tts/models/overflow.py index 92b3c767de..0218d0452b 100644 --- a/TTS/tts/models/overflow.py +++ b/TTS/tts/models/overflow.py @@ -1,3 +1,4 @@ +import logging import os from typing import Dict, List, Union @@ -20,6 +21,8 @@ from TTS.utils.generic_utils import format_aux_input from TTS.utils.io import load_fsspec +logger = logging.getLogger(__name__) + class Overflow(BaseTTS): """OverFlow TTS model. @@ -250,18 +253,17 @@ def get_criterion(): return NLLLoss() @staticmethod - def init_from_config(config: "OverFlowConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): + def init_from_config(config: "OverFlowConfig", samples: Union[List[List], List[Dict]] = None): """Initiate model from config Args: config (VitsConfig): Model config. samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. Defaults to None. - verbose (bool): If True, print init messages. Defaults to True. """ from TTS.utils.audio import AudioProcessor - ap = AudioProcessor.init_from_config(config, verbose) + ap = AudioProcessor.init_from_config(config) tokenizer, new_config = TTSTokenizer.init_from_config(config) speaker_manager = SpeakerManager.init_from_config(config, samples) return Overflow(new_config, ap, tokenizer, speaker_manager) @@ -282,14 +284,17 @@ def on_init_start(self, trainer): dataloader = trainer.get_train_dataloader( training_assets=None, samples=trainer.train_samples, verbose=False ) - print( - f" | > Data parameters not found for: {trainer.config.mel_statistics_parameter_path}. Computing mel normalization parameters..." + logger.info( + "Data parameters not found for: %s. Computing mel normalization parameters...", + trainer.config.mel_statistics_parameter_path, ) data_mean, data_std, init_transition_prob = OverflowUtils.get_data_parameters_for_flat_start( dataloader, trainer.config.out_channels, trainer.config.state_per_phone ) - print( - f" | > Saving data parameters to: {trainer.config.mel_statistics_parameter_path}: value: {data_mean, data_std, init_transition_prob}" + logger.info( + "Saving data parameters to: %s: value: %s", + trainer.config.mel_statistics_parameter_path, + (data_mean, data_std, init_transition_prob), ) statistics = { "mean": data_mean.item(), @@ -299,8 +304,9 @@ def on_init_start(self, trainer): torch.save(statistics, trainer.config.mel_statistics_parameter_path) else: - print( - f" | > Data parameters found for: {trainer.config.mel_statistics_parameter_path}. Loading mel normalization parameters..." + logger.info( + "Data parameters found for: %s. Loading mel normalization parameters...", + trainer.config.mel_statistics_parameter_path, ) statistics = torch.load(trainer.config.mel_statistics_parameter_path) data_mean, data_std, init_transition_prob = ( @@ -308,7 +314,7 @@ def on_init_start(self, trainer): statistics["std"], statistics["init_transition_prob"], ) - print(f" | > Data parameters loaded with value: {data_mean, data_std, init_transition_prob}") + logger.info("Data parameters loaded with value: %s", (data_mean, data_std, init_transition_prob)) trainer.config.flat_start_params["transition_p"] = ( init_transition_prob.item() if torch.is_tensor(init_transition_prob) else init_transition_prob @@ -334,7 +340,7 @@ def _create_logs(self, batch, outputs, ap): # pylint: disable=no-self-use, unus } # sample one item from the batch -1 will give the smalles item - print(" | > Synthesising audio from the model...") + logger.info("Synthesising audio from the model...") inference_output = self.inference( batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)} ) diff --git a/TTS/tts/models/tortoise.py b/TTS/tts/models/tortoise.py index 99e0107fdf..17303c69f7 100644 --- a/TTS/tts/models/tortoise.py +++ b/TTS/tts/models/tortoise.py @@ -1,3 +1,4 @@ +import logging import os import random from contextlib import contextmanager @@ -23,6 +24,8 @@ from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment from TTS.tts.models.base_tts import BaseTTS +logger = logging.getLogger(__name__) + def pad_or_truncate(t, length): """ @@ -100,7 +103,7 @@ def fix_autoregressive_output(codes, stop_token, complain=True): stop_token_indices = (codes == stop_token).nonzero() if len(stop_token_indices) == 0: if complain: - print( + logger.warning( "No stop tokens found in one of the generated voice clips. This typically means the spoken audio is " "too long. In some cases, the output will still be good, though. Listen to it and if it is missing words, " "try breaking up your input text." @@ -713,8 +716,7 @@ def inference( 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output" ) self.autoregressive = self.autoregressive.to(self.device) - if verbose: - print("Generating autoregressive samples..") + logger.info("Generating autoregressive samples..") with ( self.temporary_cuda(self.autoregressive) as autoregressive, torch.autocast(device_type="cuda", dtype=torch.float16, enabled=half), @@ -775,8 +777,7 @@ def inference( ) del auto_conditioning - if verbose: - print("Transforming autoregressive outputs into audio..") + logger.info("Transforming autoregressive outputs into audio..") wav_candidates = [] for b in range(best_results.shape[0]): codes = best_results[b].unsqueeze(0) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 9bc743b213..2552133753 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1,3 +1,4 @@ +import logging import math import os from dataclasses import dataclass, field, replace @@ -38,6 +39,8 @@ from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.utils.generic_utils import plot_results +logger = logging.getLogger(__name__) + ############################## # IO / Feature extraction ############################## @@ -104,9 +107,9 @@ def wav_to_spec(y, n_fft, hop_length, win_length, center=False): y = y.squeeze(1) if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) + logger.info("min value is %.3f", torch.min(y)) if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) + logger.info("max value is %.3f", torch.max(y)) global hann_window dtype_device = str(y.dtype) + "_" + str(y.device) @@ -170,9 +173,9 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm y = y.squeeze(1) if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) + logger.info("min value is %.3f", torch.min(y)) if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) + logger.info("max value is %.3f", torch.max(y)) global mel_basis, hann_window dtype_device = str(y.dtype) + "_" + str(y.device) @@ -764,7 +767,7 @@ def init_multispeaker(self, config: Coqpit): ) self.speaker_manager.encoder.eval() - print(" > External Speaker Encoder Loaded !!") + logger.info("External Speaker Encoder Loaded !!") if ( hasattr(self.speaker_manager.encoder, "audio_config") @@ -778,7 +781,7 @@ def init_multispeaker(self, config: Coqpit): def _init_speaker_embedding(self): # pylint: disable=attribute-defined-outside-init if self.num_speakers > 0: - print(" > initialization of speaker-embedding layers.") + logger.info("Initialization of speaker-embedding layers.") self.embedded_speaker_dim = self.args.speaker_embedding_channels self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) @@ -798,7 +801,7 @@ def init_multilingual(self, config: Coqpit): self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file) if self.args.use_language_embedding and self.language_manager: - print(" > initialization of language-embedding layers.") + logger.info("Initialization of language-embedding layers.") self.num_languages = self.language_manager.num_languages self.embedded_language_dim = self.args.embedded_language_dim self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim) @@ -833,7 +836,7 @@ def on_init_end(self, trainer): # pylint: disable=W0613 for key, value in after_dict.items(): if value == before_dict[key]: raise RuntimeError(" [!] The weights of Duration Predictor was not reinit check it !") - print(" > Duration Predictor was reinit.") + logger.info("Duration Predictor was reinit.") if self.args.reinit_text_encoder: before_dict = get_module_weights_sum(self.text_encoder) @@ -843,7 +846,7 @@ def on_init_end(self, trainer): # pylint: disable=W0613 for key, value in after_dict.items(): if value == before_dict[key]: raise RuntimeError(" [!] The weights of Text Encoder was not reinit check it !") - print(" > Text Encoder was reinit.") + logger.info("Text Encoder was reinit.") def get_aux_input(self, aux_input: Dict): sid, g, lid, _ = self._set_cond_input(aux_input) @@ -1437,7 +1440,7 @@ def test_run(self, assets) -> Tuple[Dict, Dict]: Returns: Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. """ - print(" | > Synthesizing test sentences.") + logger.info("Synthesizing test sentences.") test_audios = {} test_figures = {} test_sentences = self.config.test_sentences @@ -1554,14 +1557,14 @@ def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1, is_eval=F data_items = dataset.samples if getattr(config, "use_weighted_sampler", False): for attr_name, alpha in config.weighted_sampler_attrs.items(): - print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'") + logger.info("Using weighted sampler for attribute '%s' with alpha %.3f", attr_name, alpha) multi_dict = config.weighted_sampler_multipliers.get(attr_name, None) - print(multi_dict) + logger.info(multi_dict) weights, attr_names, attr_weights = get_attribute_balancer_weights( attr_name=attr_name, items=data_items, multi_dict=multi_dict ) weights = weights * alpha - print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}") + logger.info("Attribute weights for '%s' \n | > %s", attr_names, attr_weights) # input_audio_lenghts = [os.path.getsize(x["audio_file"]) for x in data_items] @@ -1609,7 +1612,6 @@ def get_data_loader( max_audio_len=config.max_audio_len, phoneme_cache_path=config.phoneme_cache_path, precompute_num_workers=config.precompute_num_workers, - verbose=verbose, tokenizer=self.tokenizer, start_by_longest=config.start_by_longest, ) @@ -1719,7 +1721,7 @@ def load_checkpoint( # handle fine-tuning from a checkpoint with additional speakers if hasattr(self, "emb_g") and state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape: num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0] - print(f" > Loading checkpoint with {num_new_speakers} additional speakers.") + logger.info("Loading checkpoint with %d additional speakers.", num_new_speakers) emb_g = state["model"]["emb_g.weight"] new_row = torch.randn(num_new_speakers, emb_g.shape[1]) emb_g = torch.cat([emb_g, new_row], axis=0) @@ -1776,7 +1778,7 @@ def load_fairseq_checkpoint( assert not self.training @staticmethod - def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): + def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None): """Initiate model from config Args: @@ -1799,7 +1801,7 @@ def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict] upsample_rate == effective_hop_length ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {effective_hop_length}" - ap = AudioProcessor.init_from_config(config, verbose=verbose) + ap = AudioProcessor.init_from_config(config) tokenizer, new_config = TTSTokenizer.init_from_config(config) speaker_manager = SpeakerManager.init_from_config(config, samples) language_manager = LanguageManager.init_from_config(config) diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 1c73c42ce9..df49cf54fd 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -1,3 +1,4 @@ +import logging import os from dataclasses import dataclass @@ -15,6 +16,8 @@ from TTS.tts.models.base_tts import BaseTTS from TTS.utils.io import load_fsspec +logger = logging.getLogger(__name__) + init_stream_support() @@ -82,7 +85,7 @@ def load_audio(audiopath, sampling_rate): # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk. # '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds. if torch.any(audio > 10) or not torch.any(audio < 0): - print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}") + logger.error("Error with %s. Max=%.2f min=%.2f", audiopath, audio.max(), audio.min()) # clip audio invalid values audio.clip_(-1, 1) return audio diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index e49695268d..5229af81c5 100644 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -1,4 +1,5 @@ import json +import logging import os from typing import Any, Dict, List, Union @@ -10,6 +11,8 @@ from TTS.config import get_from_config_or_model_args_with_default from TTS.tts.utils.managers import EmbeddingManager +logger = logging.getLogger(__name__) + class SpeakerManager(EmbeddingManager): """Manage the speakers for multi-speaker 🐸TTS models. Load a datafile and parse the information @@ -170,7 +173,9 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None, if c.use_d_vector_file: # restore speaker manager with the embedding file if not os.path.exists(speakers_file): - print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.d_vector_file") + logger.warning( + "speakers.json was not found in %s, trying to use CONFIG.d_vector_file", restore_path + ) if not os.path.exists(c.d_vector_file): raise RuntimeError( "You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.d_vector_file" @@ -193,16 +198,16 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None, speaker_manager.load_ids_from_file(c.speakers_file) if speaker_manager.num_speakers > 0: - print( - " > Speaker manager is loaded with {} speakers: {}".format( - speaker_manager.num_speakers, ", ".join(speaker_manager.name_to_id) - ) + logger.info( + "Speaker manager is loaded with %d speakers: %s", + speaker_manager.num_speakers, + ", ".join(speaker_manager.name_to_id), ) # save file if path is defined if out_path: out_file_path = os.path.join(out_path, "speakers.json") - print(f" > Saving `speakers.json` to {out_file_path}.") + logger.info("Saving `speakers.json` to %s", out_file_path) if c.use_d_vector_file and c.d_vector_file: speaker_manager.save_embeddings_to_file(out_file_path) else: diff --git a/TTS/tts/utils/text/characters.py b/TTS/tts/utils/text/characters.py index 37c7a7ca23..c622b93c59 100644 --- a/TTS/tts/utils/text/characters.py +++ b/TTS/tts/utils/text/characters.py @@ -1,8 +1,11 @@ +import logging from dataclasses import replace from typing import Dict from TTS.tts.configs.shared_configs import CharactersConfig +logger = logging.getLogger(__name__) + def parse_symbols(): return { @@ -305,14 +308,14 @@ def print_log(self, level: int = 0): Prints the vocabulary in a nice format. """ indent = "\t" * level - print(f"{indent}| > Characters: {self._characters}") - print(f"{indent}| > Punctuations: {self._punctuations}") - print(f"{indent}| > Pad: {self._pad}") - print(f"{indent}| > EOS: {self._eos}") - print(f"{indent}| > BOS: {self._bos}") - print(f"{indent}| > Blank: {self._blank}") - print(f"{indent}| > Vocab: {self.vocab}") - print(f"{indent}| > Num chars: {self.num_chars}") + logger.info("%s| Characters: %s", indent, self._characters) + logger.info("%s| Punctuations: %s", indent, self._punctuations) + logger.info("%s| Pad: %s", indent, self._pad) + logger.info("%s| EOS: %s", indent, self._eos) + logger.info("%s| BOS: %s", indent, self._bos) + logger.info("%s| Blank: %s", indent, self._blank) + logger.info("%s| Vocab: %s", indent, self.vocab) + logger.info("%s| Num chars: %d", indent, self.num_chars) @staticmethod def init_from_config(config: "Coqpit"): # pylint: disable=unused-argument diff --git a/TTS/tts/utils/text/phonemizers/base.py b/TTS/tts/utils/text/phonemizers/base.py index 4fc7987415..5e701df458 100644 --- a/TTS/tts/utils/text/phonemizers/base.py +++ b/TTS/tts/utils/text/phonemizers/base.py @@ -1,8 +1,11 @@ import abc +import logging from typing import List, Tuple from TTS.tts.utils.text.punctuation import Punctuation +logger = logging.getLogger(__name__) + class BasePhonemizer(abc.ABC): """Base phonemizer class @@ -136,5 +139,5 @@ def phonemize(self, text: str, separator="|", language: str = None) -> str: # p def print_logs(self, level: int = 0): indent = "\t" * level - print(f"{indent}| > phoneme language: {self.language}") - print(f"{indent}| > phoneme backend: {self.name()}") + logger.info("%s| phoneme language: %s", indent, self.language) + logger.info("%s| phoneme backend: %s", indent, self.name()) diff --git a/TTS/tts/utils/text/phonemizers/espeak_wrapper.py b/TTS/tts/utils/text/phonemizers/espeak_wrapper.py index 328e52f369..d1d2335037 100644 --- a/TTS/tts/utils/text/phonemizers/espeak_wrapper.py +++ b/TTS/tts/utils/text/phonemizers/espeak_wrapper.py @@ -8,6 +8,8 @@ from TTS.tts.utils.text.phonemizers.base import BasePhonemizer from TTS.tts.utils.text.punctuation import Punctuation +logger = logging.getLogger(__name__) + def is_tool(name): from shutil import which @@ -53,7 +55,7 @@ def _espeak_exe(espeak_lib: str, args: List, sync=False) -> List[str]: "1", # UTF8 text encoding ] cmd.extend(args) - logging.debug("espeakng: executing %s", repr(cmd)) + logger.debug("espeakng: executing %s", repr(cmd)) with subprocess.Popen( cmd, @@ -189,7 +191,7 @@ def phonemize_espeak(self, text: str, separator: str = "|", tie=False) -> str: # compute phonemes phonemes = "" for line in _espeak_exe(self._ESPEAK_LIB, args, sync=True): - logging.debug("line: %s", repr(line)) + logger.debug("line: %s", repr(line)) ph_decoded = line.decode("utf8").strip() # espeak: # version 1.48.15: " p_ɹ_ˈaɪ_ɚ t_ə n_oʊ_v_ˈɛ_m_b_ɚ t_w_ˈɛ_n_t_i t_ˈuː\n" @@ -227,7 +229,7 @@ def supported_languages() -> Dict: lang_code = cols[1] lang_name = cols[3] langs[lang_code] = lang_name - logging.debug("line: %s", repr(line)) + logger.debug("line: %s", repr(line)) count += 1 return langs @@ -240,7 +242,7 @@ def version(self) -> str: args = ["--version"] for line in _espeak_exe(self.backend, args, sync=True): version = line.decode("utf8").strip().split()[2] - logging.debug("line: %s", repr(line)) + logger.debug("line: %s", repr(line)) return version @classmethod diff --git a/TTS/tts/utils/text/phonemizers/multi_phonemizer.py b/TTS/tts/utils/text/phonemizers/multi_phonemizer.py index 62a9c39322..1a9e98b091 100644 --- a/TTS/tts/utils/text/phonemizers/multi_phonemizer.py +++ b/TTS/tts/utils/text/phonemizers/multi_phonemizer.py @@ -1,7 +1,10 @@ +import logging from typing import Dict, List from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name +logger = logging.getLogger(__name__) + class MultiPhonemizer: """🐸TTS multi-phonemizer that operates phonemizers for multiple langugages @@ -46,8 +49,8 @@ def supported_languages(self) -> List: def print_logs(self, level: int = 0): indent = "\t" * level - print(f"{indent}| > phoneme language: {self.supported_languages()}") - print(f"{indent}| > phoneme backend: {self.name()}") + logger.info("%s| phoneme language: %s", indent, self.supported_languages()) + logger.info("%s| phoneme backend: %s", indent, self.name()) # if __name__ == "__main__": diff --git a/TTS/tts/utils/text/tokenizer.py b/TTS/tts/utils/text/tokenizer.py index b7faf86e8a..9aff7dd4bb 100644 --- a/TTS/tts/utils/text/tokenizer.py +++ b/TTS/tts/utils/text/tokenizer.py @@ -1,3 +1,4 @@ +import logging from typing import Callable, Dict, List, Union from TTS.tts.utils.text import cleaners @@ -6,6 +7,8 @@ from TTS.tts.utils.text.phonemizers.multi_phonemizer import MultiPhonemizer from TTS.utils.generic_utils import get_import_path, import_class +logger = logging.getLogger(__name__) + class TTSTokenizer: """🐸TTS tokenizer to convert input characters to token IDs and back. @@ -73,8 +76,8 @@ def encode(self, text: str) -> List[int]: # discard but store not found characters if char not in self.not_found_characters: self.not_found_characters.append(char) - print(text) - print(f" [!] Character {repr(char)} not found in the vocabulary. Discarding it.") + logger.warning(text) + logger.warning("Character %s not found in the vocabulary. Discarding it.", repr(char)) return token_ids def decode(self, token_ids: List[int]) -> str: @@ -135,16 +138,16 @@ def intersperse_blank_char(self, char_sequence: List[str], use_blank_char: bool def print_logs(self, level: int = 0): indent = "\t" * level - print(f"{indent}| > add_blank: {self.add_blank}") - print(f"{indent}| > use_eos_bos: {self.use_eos_bos}") - print(f"{indent}| > use_phonemes: {self.use_phonemes}") + logger.info("%s| add_blank: %s", indent, self.add_blank) + logger.info("%s| use_eos_bos: %s", indent, self.use_eos_bos) + logger.info("%s| use_phonemes: %s", indent, self.use_phonemes) if self.use_phonemes: - print(f"{indent}| > phonemizer:") + logger.info("%s| phonemizer:", indent) self.phonemizer.print_logs(level + 1) if len(self.not_found_characters) > 0: - print(f"{indent}| > {len(self.not_found_characters)} not found characters:") + logger.info("%s| %d characters not found:", indent, len(self.not_found_characters)) for char in self.not_found_characters: - print(f"{indent}| > {char}") + logger.info("%s| %s", indent, char) @staticmethod def init_from_config(config: "Coqpit", characters: "BaseCharacters" = None): diff --git a/TTS/utils/audio/numpy_transforms.py b/TTS/utils/audio/numpy_transforms.py index af88569fc3..4a8972480c 100644 --- a/TTS/utils/audio/numpy_transforms.py +++ b/TTS/utils/audio/numpy_transforms.py @@ -1,3 +1,4 @@ +import logging from io import BytesIO from typing import Tuple @@ -7,6 +8,8 @@ import soundfile as sf from librosa import magphase, pyin +logger = logging.getLogger(__name__) + # For using kwargs # pylint: disable=unused-argument @@ -222,7 +225,7 @@ def griffin_lim(*, spec: np.ndarray = None, num_iter=60, **kwargs) -> np.ndarray S_complex = np.abs(spec).astype(complex) y = istft(y=S_complex * angles, **kwargs) if not np.isfinite(y).all(): - print(" [!] Waveform is not finite everywhere. Skipping the GL.") + logger.warning("Waveform is not finite everywhere. Skipping the GL.") return np.array([0.0]) for _ in range(num_iter): angles = np.exp(1j * np.angle(stft(y=y, **kwargs))) diff --git a/TTS/utils/audio/processor.py b/TTS/utils/audio/processor.py index c53bad562e..680e29debc 100644 --- a/TTS/utils/audio/processor.py +++ b/TTS/utils/audio/processor.py @@ -1,3 +1,4 @@ +import logging from io import BytesIO from typing import Dict, Tuple @@ -26,6 +27,8 @@ volume_norm, ) +logger = logging.getLogger(__name__) + # pylint: disable=too-many-public-methods @@ -132,10 +135,6 @@ class AudioProcessor(object): stats_path (str, optional): Path to the computed stats file. Defaults to None. - - verbose (bool, optional): - enable/disable logging. Defaults to True. - """ def __init__( @@ -172,7 +171,6 @@ def __init__( do_rms_norm=False, db_level=None, stats_path=None, - verbose=True, **_, ): # setup class attributed @@ -228,10 +226,9 @@ def __init__( self.win_length <= self.fft_size ), f" [!] win_length cannot be larger than fft_size - {self.win_length} vs {self.fft_size}" members = vars(self) - if verbose: - print(" > Setting up Audio Processor...") - for key, value in members.items(): - print(" | > {}:{}".format(key, value)) + logger.info("Setting up Audio Processor...") + for key, value in members.items(): + logger.info(" | %s: %s", key, value) # create spectrogram utils self.mel_basis = build_mel_basis( sample_rate=self.sample_rate, @@ -250,10 +247,10 @@ def __init__( self.symmetric_norm = None @staticmethod - def init_from_config(config: "Coqpit", verbose=True): + def init_from_config(config: "Coqpit"): if "audio" in config: - return AudioProcessor(verbose=verbose, **config.audio) - return AudioProcessor(verbose=verbose, **config) + return AudioProcessor(**config.audio) + return AudioProcessor(**config) ### normalization ### def normalize(self, S: np.ndarray) -> np.ndarray: @@ -595,7 +592,7 @@ def load_wav(self, filename: str, sr: int = None) -> np.ndarray: try: x = self.trim_silence(x) except ValueError: - print(f" [!] File cannot be trimmed for silence - {filename}") + logger.exception("File cannot be trimmed for silence - %s", filename) if self.do_sound_norm: x = self.sound_norm(x) if self.do_rms_norm: diff --git a/TTS/utils/download.py b/TTS/utils/download.py index 37e6ed3cee..e94b1d68c8 100644 --- a/TTS/utils/download.py +++ b/TTS/utils/download.py @@ -12,6 +12,8 @@ from torch.utils.model_zoo import tqdm +logger = logging.getLogger(__name__) + def stream_url( url: str, start_byte: Optional[int] = None, block_size: int = 32 * 1024, progress_bar: bool = True @@ -149,20 +151,20 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bo Returns: list: List of paths to extracted files even if not overwritten. """ - + logger.info("Extracting archive file...") if to_path is None: to_path = os.path.dirname(from_path) try: with tarfile.open(from_path, "r") as tar: - logging.info("Opened tar file %s.", from_path) + logger.info("Opened tar file %s.", from_path) files = [] for file_ in tar: # type: Any file_path = os.path.join(to_path, file_.name) if file_.isfile(): files.append(file_path) if os.path.exists(file_path): - logging.info("%s already extracted.", file_path) + logger.info("%s already extracted.", file_path) if not overwrite: continue tar.extract(file_, to_path) @@ -172,12 +174,12 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bo try: with zipfile.ZipFile(from_path, "r") as zfile: - logging.info("Opened zip file %s.", from_path) + logger.info("Opened zip file %s.", from_path) files = zfile.namelist() for file_ in files: file_path = os.path.join(to_path, file_) if os.path.exists(file_path): - logging.info("%s already extracted.", file_path) + logger.info("%s already extracted.", file_path) if not overwrite: continue zfile.extract(file_, to_path) @@ -201,9 +203,10 @@ def download_kaggle_dataset(dataset_path: str, dataset_name: str, output_path: s import kaggle # pylint: disable=import-outside-toplevel kaggle.api.authenticate() - print(f"""\nDownloading {dataset_name}...""") + logger.info("Downloading %s...", dataset_name) kaggle.api.dataset_download_files(dataset_path, path=data_path, unzip=True) except OSError: - print( - f"""[!] in order to download kaggle datasets, you need to have a kaggle api token stored in your {os.path.join(expanduser('~'), '.kaggle/kaggle.json')}""" + logger.exception( + "In order to download kaggle datasets, you need to have a kaggle api token stored in your %s", + os.path.join(expanduser("~"), ".kaggle/kaggle.json"), ) diff --git a/TTS/utils/downloaders.py b/TTS/utils/downloaders.py index 104dc7b94e..8705873982 100644 --- a/TTS/utils/downloaders.py +++ b/TTS/utils/downloaders.py @@ -1,8 +1,11 @@ +import logging import os from typing import Optional from TTS.utils.download import download_kaggle_dataset, download_url, extract_archive +logger = logging.getLogger(__name__) + def download_ljspeech(path: str): """Download and extract LJSpeech dataset @@ -15,7 +18,6 @@ def download_ljspeech(path: str): download_url(url, path) basename = os.path.basename(url) archive = os.path.join(path, basename) - print(" > Extracting archive file...") extract_archive(archive) @@ -35,7 +37,6 @@ def download_vctk(path: str, use_kaggle: Optional[bool] = False): download_url(url, path) basename = os.path.basename(url) archive = os.path.join(path, basename) - print(" > Extracting archive file...") extract_archive(archive) @@ -71,19 +72,17 @@ def download_libri_tts(path: str, subset: Optional[str] = "all"): os.makedirs(path, exist_ok=True) if subset == "all": for sub, val in subset_dict.items(): - print(f" > Downloading {sub}...") + logger.info("Downloading %s...", sub) download_url(val, path) basename = os.path.basename(val) archive = os.path.join(path, basename) - print(" > Extracting archive file...") extract_archive(archive) - print(" > All subsets downloaded") + logger.info("All subsets downloaded") else: url = subset_dict[subset] download_url(url, path) basename = os.path.basename(url) archive = os.path.join(path, basename) - print(" > Extracting archive file...") extract_archive(archive) @@ -98,7 +97,6 @@ def download_thorsten_de(path: str): download_url(url, path) basename = os.path.basename(url) archive = os.path.join(path, basename) - print(" > Extracting archive file...") extract_archive(archive) @@ -122,5 +120,4 @@ def download_mailabs(path: str, language: str = "english"): download_url(url, path) basename = os.path.basename(url) archive = os.path.join(path, basename) - print(" > Extracting archive file...") extract_archive(archive) diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py index e0cd3ad85f..024d50277c 100644 --- a/TTS/utils/generic_utils.py +++ b/TTS/utils/generic_utils.py @@ -7,7 +7,9 @@ import subprocess import sys from pathlib import Path -from typing import Dict +from typing import Dict, Optional + +logger = logging.getLogger(__name__) # TODO: This method is duplicated in Trainer but out of date there @@ -91,7 +93,7 @@ def set_init_dict(model_dict, checkpoint_state, c): # Partial initialization: if there is a mismatch with new and old layer, it is skipped. for k, v in checkpoint_state.items(): if k not in model_dict: - print(" | > Layer missing in the model definition: {}".format(k)) + logger.warning("Layer missing in the model finition %s", k) # 1. filter out unnecessary keys pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict} # 2. filter out different size layers @@ -102,7 +104,7 @@ def set_init_dict(model_dict, checkpoint_state, c): pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k} # 4. overwrite entries in the existing state dict model_dict.update(pretrained_dict) - print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict))) + logger.info("%d / %d layers are restored.", len(pretrained_dict), len(model_dict)) return model_dict @@ -123,16 +125,43 @@ def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict: return kwargs -def get_timestamp(): +def get_timestamp() -> str: return datetime.now().strftime("%y%m%d-%H%M%S") -def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False): +class ConsoleFormatter(logging.Formatter): + """Custom formatter that prints logging.INFO messages without the level name. + + Source: https://stackoverflow.com/a/62488520 + """ + + def format(self, record): + if record.levelno == logging.INFO: + self._style._fmt = "%(message)s" + else: + self._style._fmt = "%(levelname)s: %(message)s" + return super().format(record) + + +def setup_logger( + logger_name: str, + level: int = logging.INFO, + *, + formatter: Optional[logging.Formatter] = None, + screen: bool = False, + tofile: bool = False, + log_dir: str = "logs", + log_name: str = "log", +) -> None: lg = logging.getLogger(logger_name) - formatter = logging.Formatter("%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s", datefmt="%y-%m-%d %H:%M:%S") + if formatter is None: + formatter = logging.Formatter( + "%(asctime)s.%(msecs)03d - %(levelname)-8s - %(name)s: %(message)s", datefmt="%y-%m-%d %H:%M:%S" + ) lg.setLevel(level) if tofile: - log_file = os.path.join(root, phase + "_{}.log".format(get_timestamp())) + Path(log_dir).mkdir(exist_ok=True, parents=True) + log_file = Path(log_dir) / f"{log_name}_{get_timestamp()}.log" fh = logging.FileHandler(log_file, mode="w") fh.setFormatter(formatter) lg.addHandler(fh) diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index ca16183d37..d4781d54e6 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -1,4 +1,5 @@ import json +import logging import os import re import tarfile @@ -14,6 +15,8 @@ from TTS.config import load_config, read_json_with_comments from TTS.utils.generic_utils import get_user_data_dir +logger = logging.getLogger(__name__) + LICENSE_URLS = { "cc by-nc-nd 4.0": "https://creativecommons.org/licenses/by-nc-nd/4.0/", "mpl": "https://www.mozilla.org/en-US/MPL/2.0/", @@ -40,13 +43,11 @@ class ModelManager(object): models_file (str): path to .model.json file. Defaults to None. output_prefix (str): prefix to `tts` to download models. Defaults to None progress_bar (bool): print a progress bar when donwloading a file. Defaults to False. - verbose (bool): print info. Defaults to True. """ - def __init__(self, models_file=None, output_prefix=None, progress_bar=False, verbose=True): + def __init__(self, models_file=None, output_prefix=None, progress_bar=False): super().__init__() self.progress_bar = progress_bar - self.verbose = verbose if output_prefix is None: self.output_prefix = get_user_data_dir("tts") else: @@ -68,19 +69,16 @@ def read_models_file(self, file_path): self.models_dict = read_json_with_comments(file_path) def _list_models(self, model_type, model_count=0): - if self.verbose: - print("\n Name format: type/language/dataset/model") + logger.info("") + logger.info("Name format: type/language/dataset/model") model_list = [] for lang in self.models_dict[model_type]: for dataset in self.models_dict[model_type][lang]: for model in self.models_dict[model_type][lang][dataset]: model_full_name = f"{model_type}--{lang}--{dataset}--{model}" - output_path = os.path.join(self.output_prefix, model_full_name) - if self.verbose: - if os.path.exists(output_path): - print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]") - else: - print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}") + output_path = Path(self.output_prefix) / model_full_name + downloaded = " [already downloaded]" if output_path.is_dir() else "" + logger.info(" %2d: %s/%s/%s/%s%s", model_count, model_type, lang, dataset, model, downloaded) model_list.append(f"{model_type}/{lang}/{dataset}/{model}") model_count += 1 return model_list @@ -99,21 +97,36 @@ def list_models(self): models_name_list.extend(model_list) return models_name_list + def log_model_details(self, model_type, lang, dataset, model): + logger.info("Model type: %s", model_type) + logger.info("Language supported: %s", lang) + logger.info("Dataset used: %s", dataset) + logger.info("Model name: %s", model) + if "description" in self.models_dict[model_type][lang][dataset][model]: + logger.info("Description: %s", self.models_dict[model_type][lang][dataset][model]["description"]) + else: + logger.info("Description: coming soon") + if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]: + logger.info( + "Default vocoder: %s", + self.models_dict[model_type][lang][dataset][model]["default_vocoder"], + ) + def model_info_by_idx(self, model_query): - """Print the description of the model from .models.json file using model_idx + """Print the description of the model from .models.json file using model_query_idx Args: - model_query (str): / + model_query (str): / """ model_name_list = [] model_type, model_query_idx = model_query.split("/") try: model_query_idx = int(model_query_idx) if model_query_idx <= 0: - print("> model_query_idx should be a positive integer!") + logger.error("model_query_idx [%d] should be a positive integer!", model_query_idx) return - except: - print("> model_query_idx should be an integer!") + except (TypeError, ValueError): + logger.error("model_query_idx [%s] should be an integer!", model_query_idx) return model_count = 0 if model_type in self.models_dict: @@ -123,22 +136,13 @@ def model_info_by_idx(self, model_query): model_name_list.append(f"{model_type}/{lang}/{dataset}/{model}") model_count += 1 else: - print(f"> model_type {model_type} does not exist in the list.") + logger.error("Model type %s does not exist in the list.", model_type) return if model_query_idx > model_count: - print(f"model query idx exceeds the number of available models [{model_count}] ") + logger.error("model_query_idx exceeds the number of available models [%d]", model_count) else: model_type, lang, dataset, model = model_name_list[model_query_idx - 1].split("/") - print(f"> model type : {model_type}") - print(f"> language supported : {lang}") - print(f"> dataset used : {dataset}") - print(f"> model name : {model}") - if "description" in self.models_dict[model_type][lang][dataset][model]: - print(f"> description : {self.models_dict[model_type][lang][dataset][model]['description']}") - else: - print("> description : coming soon") - if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]: - print(f"> default_vocoder : {self.models_dict[model_type][lang][dataset][model]['default_vocoder']}") + self.log_model_details(model_type, lang, dataset, model) def model_info_by_full_name(self, model_query_name): """Print the description of the model from .models.json file using model_full_name @@ -147,32 +151,19 @@ def model_info_by_full_name(self, model_query_name): model_query_name (str): Format is /// """ model_type, lang, dataset, model = model_query_name.split("/") - if model_type in self.models_dict: - if lang in self.models_dict[model_type]: - if dataset in self.models_dict[model_type][lang]: - if model in self.models_dict[model_type][lang][dataset]: - print(f"> model type : {model_type}") - print(f"> language supported : {lang}") - print(f"> dataset used : {dataset}") - print(f"> model name : {model}") - if "description" in self.models_dict[model_type][lang][dataset][model]: - print( - f"> description : {self.models_dict[model_type][lang][dataset][model]['description']}" - ) - else: - print("> description : coming soon") - if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]: - print( - f"> default_vocoder : {self.models_dict[model_type][lang][dataset][model]['default_vocoder']}" - ) - else: - print(f"> model {model} does not exist for {model_type}/{lang}/{dataset}.") - else: - print(f"> dataset {dataset} does not exist for {model_type}/{lang}.") - else: - print(f"> lang {lang} does not exist for {model_type}.") - else: - print(f"> model_type {model_type} does not exist in the list.") + if model_type not in self.models_dict: + logger.error("Model type %s does not exist in the list.", model_type) + return + if lang not in self.models_dict[model_type]: + logger.error("Language %s does not exist for %s.", lang, model_type) + return + if dataset not in self.models_dict[model_type][lang]: + logger.error("Dataset %s does not exist for %s/%s.", dataset, model_type, lang) + return + if model not in self.models_dict[model_type][lang][dataset]: + logger.error("Model %s does not exist for %s/%s/%s.", model, model_type, lang, dataset) + return + self.log_model_details(model_type, lang, dataset, model) def list_tts_models(self): """Print all `TTS` models and return a list of model names @@ -197,18 +188,18 @@ def list_vc_models(self): def list_langs(self): """Print all the available languages""" - print(" Name format: type/language") + logger.info("Name format: type/language") for model_type in self.models_dict: for lang in self.models_dict[model_type]: - print(f" >: {model_type}/{lang} ") + logger.info(" %s/%s", model_type, lang) def list_datasets(self): """Print all the datasets""" - print(" Name format: type/language/dataset") + logger.info("Name format: type/language/dataset") for model_type in self.models_dict: for lang in self.models_dict[model_type]: for dataset in self.models_dict[model_type][lang]: - print(f" >: {model_type}/{lang}/{dataset}") + logger.info(" %s/%s/%s", model_type, lang, dataset) @staticmethod def print_model_license(model_item: Dict): @@ -218,13 +209,13 @@ def print_model_license(model_item: Dict): model_item (dict): model item in the models.json """ if "license" in model_item and model_item["license"].strip() != "": - print(f" > Model's license - {model_item['license']}") + logger.info("Model's license - %s", model_item["license"]) if model_item["license"].lower() in LICENSE_URLS: - print(f" > Check {LICENSE_URLS[model_item['license'].lower()]} for more info.") + logger.info("Check %s for more info.", LICENSE_URLS[model_item["license"].lower()]) else: - print(" > Check https://opensource.org/licenses for more info.") + logger.info("Check https://opensource.org/licenses for more info.") else: - print(" > Model's license - No license information available") + logger.info("Model's license - No license information available") def _download_github_model(self, model_item: Dict, output_path: str): if isinstance(model_item["github_rls_url"], list): @@ -336,7 +327,7 @@ def create_dir_and_download_model(self, model_name, model_item, output_path): if not self.ask_tos(output_path): os.rmdir(output_path) raise Exception(" [!] You must agree to the terms of service to use this model.") - print(f" > Downloading model to {output_path}") + logger.info("Downloading model to %s", output_path) try: if "fairseq" in model_name: self.download_fairseq_model(model_name, output_path) @@ -346,7 +337,7 @@ def create_dir_and_download_model(self, model_name, model_item, output_path): self._download_hf_model(model_item, output_path) except requests.RequestException as e: - print(f" > Failed to download the model file to {output_path}") + logger.exception("Failed to download the model file to %s", output_path) rmtree(output_path) raise e self.print_model_license(model_item=model_item) @@ -364,7 +355,7 @@ def check_if_configs_are_equal(self, model_name, model_item, output_path): config_remote = json.load(f) if not config_local == config_remote: - print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...") + logger.info("%s is already downloaded however it has been changed. Redownloading it...", model_name) self.create_dir_and_download_model(model_name, model_item, output_path) def download_model(self, model_name): @@ -390,12 +381,12 @@ def download_model(self, model_name): if os.path.isfile(md5sum_file): with open(md5sum_file, mode="r") as f: if not f.read() == md5sum: - print(f" > {model_name} has been updated, clearing model cache...") + logger.info("%s has been updated, clearing model cache...", model_name) self.create_dir_and_download_model(model_name, model_item, output_path) else: - print(f" > {model_name} is already downloaded.") + logger.info("%s is already downloaded.", model_name) else: - print(f" > {model_name} has been updated, clearing model cache...") + logger.info("%s has been updated, clearing model cache...", model_name) self.create_dir_and_download_model(model_name, model_item, output_path) # if the configs are different, redownload it # ToDo: we need a better way to handle it @@ -405,7 +396,7 @@ def download_model(self, model_name): except: pass else: - print(f" > {model_name} is already downloaded.") + logger.info("%s is already downloaded.", model_name) else: self.create_dir_and_download_model(model_name, model_item, output_path) @@ -544,7 +535,7 @@ def _download_zip_file(file_url, output_folder, progress_bar): z.extractall(output_folder) os.remove(temp_zip_name) # delete zip after extract except zipfile.BadZipFile: - print(f" > Error: Bad zip file - {file_url}") + logger.exception("Bad zip file - %s", file_url) raise zipfile.BadZipFile # pylint: disable=raise-missing-from # move the files to the outer path for file_path in z.namelist(): @@ -580,7 +571,7 @@ def _download_tar_file(file_url, output_folder, progress_bar): tar_names = t.getnames() os.remove(temp_tar_name) # delete tar after extract except tarfile.ReadError: - print(f" > Error: Bad tar file - {file_url}") + logger.exception("Bad tar file - %s", file_url) raise tarfile.ReadError # pylint: disable=raise-missing-from # move the files to the outer path for file_path in os.listdir(os.path.join(output_folder, tar_names[0])): diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 6165fb5e8a..50a7893047 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -1,3 +1,4 @@ +import logging import os import time from typing import List @@ -21,6 +22,8 @@ from TTS.vocoder.models import setup_model as setup_vocoder_model from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input +logger = logging.getLogger(__name__) + class Synthesizer(nn.Module): def __init__( @@ -218,7 +221,7 @@ def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> N use_cuda (bool): enable/disable CUDA use. """ self.vocoder_config = load_config(model_config) - self.vocoder_ap = AudioProcessor(verbose=False, **self.vocoder_config.audio) + self.vocoder_ap = AudioProcessor(**self.vocoder_config.audio) self.vocoder_model = setup_vocoder_model(self.vocoder_config) self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True) if use_cuda: @@ -294,9 +297,9 @@ def tts( if text: sens = [text] if split_sentences: - print(" > Text splitted to sentences.") sens = self.split_into_sentences(text) - print(sens) + logger.info("Text split into sentences.") + logger.info("Input: %s", sens) # handle multi-speaker if "voice_dir" in kwargs: @@ -420,7 +423,7 @@ def tts( self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate, ] if scale_factor[1] != 1: - print(" > interpolating tts model output.") + logger.info("Interpolating TTS model output.") vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input) else: vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable @@ -484,7 +487,7 @@ def tts( self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate, ] if scale_factor[1] != 1: - print(" > interpolating tts model output.") + logger.info("Interpolating TTS model output.") vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input) else: vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable @@ -500,6 +503,6 @@ def tts( # compute stats process_time = time.time() - start_time audio_time = len(wavs) / self.tts_config.audio["sample_rate"] - print(f" > Processing time: {process_time}") - print(f" > Real-time factor: {process_time / audio_time}") + logger.info("Processing time: %.3f", process_time) + logger.info("Real-time factor: %.3f", process_time / audio_time) return wavs diff --git a/TTS/utils/training.py b/TTS/utils/training.py index b51f55e92b..57885005f1 100644 --- a/TTS/utils/training.py +++ b/TTS/utils/training.py @@ -1,6 +1,10 @@ +import logging + import numpy as np import torch +logger = logging.getLogger(__name__) + def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None): r"""Check model gradient against unexpected jumps and failures""" @@ -21,11 +25,11 @@ def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None): # compatibility with different torch versions if isinstance(grad_norm, float): if np.isinf(grad_norm): - print(" | > Gradient is INF !!") + logger.warning("Gradient is INF !!") skip_flag = True else: if torch.isinf(grad_norm): - print(" | > Gradient is INF !!") + logger.warning("Gradient is INF !!") skip_flag = True return grad_norm, skip_flag diff --git a/TTS/utils/vad.py b/TTS/utils/vad.py index aefce2b50b..49c8dc6b66 100644 --- a/TTS/utils/vad.py +++ b/TTS/utils/vad.py @@ -1,6 +1,10 @@ +import logging + import torch import torchaudio +logger = logging.getLogger(__name__) + def read_audio(path): wav, sr = torchaudio.load(path) @@ -54,8 +58,8 @@ def remove_silence( # read ground truth wav and resample the audio for the VAD try: wav, gt_sample_rate = read_audio(audio_path) - except: - print(f"> ❗ Failed to read {audio_path}") + except Exception: + logger.exception("Failed to read %s", audio_path) return None, False # if needed, resample the audio for the VAD model @@ -80,7 +84,7 @@ def remove_silence( wav = collect_chunks(new_speech_timestamps, wav) is_speech = True else: - print(f"> The file {audio_path} probably does not have speech please check it !!") + logger.warning("The file %s probably does not have speech please check it!", audio_path) is_speech = False # save diff --git a/TTS/vc/models/__init__.py b/TTS/vc/models/__init__.py index 5a09b4e53e..a498b292b7 100644 --- a/TTS/vc/models/__init__.py +++ b/TTS/vc/models/__init__.py @@ -1,7 +1,10 @@ import importlib +import logging import re from typing import Dict, List, Union +logger = logging.getLogger(__name__) + def to_camel(text): text = text.capitalize() @@ -9,7 +12,7 @@ def to_camel(text): def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseVC": - print(" > Using model: {}".format(config.model)) + logger.info("Using model: %s", config.model) # fetch the right model implementation. if "model" in config and config["model"].lower() == "freevc": MyModel = importlib.import_module("TTS.vc.models.freevc").FreeVC diff --git a/TTS/vc/models/base_vc.py b/TTS/vc/models/base_vc.py index 78f1556b71..c387157f19 100644 --- a/TTS/vc/models/base_vc.py +++ b/TTS/vc/models/base_vc.py @@ -1,3 +1,4 @@ +import logging import os import random from typing import Dict, List, Tuple, Union @@ -20,6 +21,8 @@ # pylint: skip-file +logger = logging.getLogger(__name__) + class BaseVC(BaseTrainerModel): """Base `vc` class. Every new `vc` model must inherit this. @@ -93,7 +96,7 @@ def init_multispeaker(self, config: Coqpit, data: List = None): ) # init speaker embedding layer if config.use_speaker_embedding and not config.use_d_vector_file: - print(" > Init speaker_embedding layer.") + logger.info("Init speaker_embedding layer.") self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) self.speaker_embedding.weight.data.normal_(0, 0.3) @@ -233,12 +236,12 @@ def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1): if getattr(config, "use_language_weighted_sampler", False): alpha = getattr(config, "language_weighted_sampler_alpha", 1.0) - print(" > Using Language weighted sampler with alpha:", alpha) + logger.info("Using Language weighted sampler with alpha: %.2f", alpha) weights = get_language_balancer_weights(data_items) * alpha if getattr(config, "use_speaker_weighted_sampler", False): alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0) - print(" > Using Speaker weighted sampler with alpha:", alpha) + logger.info("Using Speaker weighted sampler with alpha: %.2f", alpha) if weights is not None: weights += get_speaker_balancer_weights(data_items) * alpha else: @@ -246,7 +249,7 @@ def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1): if getattr(config, "use_length_weighted_sampler", False): alpha = getattr(config, "length_weighted_sampler_alpha", 1.0) - print(" > Using Length weighted sampler with alpha:", alpha) + logger.info("Using Length weighted sampler with alpha: %.2f", alpha) if weights is not None: weights += get_length_balancer_weights(data_items) * alpha else: @@ -318,7 +321,6 @@ def get_data_loader( phoneme_cache_path=config.phoneme_cache_path, precompute_num_workers=config.precompute_num_workers, use_noise_augment=False if is_eval else config.use_noise_augment, - verbose=verbose, speaker_id_mapping=speaker_id_mapping, d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, tokenizer=None, @@ -378,7 +380,7 @@ def test_run(self, assets: Dict) -> Tuple[Dict, Dict]: Returns: Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. """ - print(" | > Synthesizing test sentences.") + logger.info("Synthesizing test sentences.") test_audios = {} test_figures = {} test_sentences = self.config.test_sentences @@ -417,8 +419,8 @@ def on_init_start(self, trainer): if hasattr(trainer.config, "model_args"): trainer.config.model_args.speakers_file = output_path trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) - print(f" > `speakers.pth` is saved to {output_path}.") - print(" > `speakers_file` is updated in the config.json.") + logger.info("`speakers.pth` is saved to %s", output_path) + logger.info("`speakers_file` is updated in the config.json.") if self.language_manager is not None: output_path = os.path.join(trainer.output_path, "language_ids.json") @@ -427,5 +429,5 @@ def on_init_start(self, trainer): if hasattr(trainer.config, "model_args"): trainer.config.model_args.language_ids_file = output_path trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) - print(f" > `language_ids.json` is saved to {output_path}.") - print(" > `language_ids_file` is updated in the config.json.") + logger.info("`language_ids.json` is saved to %s", output_path) + logger.info("`language_ids_file` is updated in the config.json.") diff --git a/TTS/vc/models/freevc.py b/TTS/vc/models/freevc.py index 8f2a35d204..f9e691256e 100644 --- a/TTS/vc/models/freevc.py +++ b/TTS/vc/models/freevc.py @@ -1,3 +1,4 @@ +import logging from typing import Dict, List, Optional, Tuple, Union import librosa @@ -22,6 +23,8 @@ from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx from TTS.vc.modules.freevc.wavlm import get_wavlm +logger = logging.getLogger(__name__) + class ResidualCouplingBlock(nn.Module): def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0): @@ -152,7 +155,7 @@ def forward(self, x, g=None): return x def remove_weight_norm(self): - print("Removing weight norm...") + logger.info("Removing weight norm...") for l in self.ups: remove_parametrizations(l, "weight") for l in self.resblocks: @@ -377,7 +380,7 @@ def device(self): def load_pretrained_speaker_encoder(self): """Load pretrained speaker encoder model as mentioned in the paper.""" - print(" > Loading pretrained speaker encoder model ...") + logger.info("Loading pretrained speaker encoder model ...") self.enc_spk_ex = SpeakerEncoderEx( "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/speaker_encoder.pt" ) @@ -547,7 +550,7 @@ def voice_conversion(self, src, tgt): def eval_step(): ... @staticmethod - def init_from_config(config: FreeVCConfig, samples: Union[List[List], List[Dict]] = None, verbose=True): + def init_from_config(config: FreeVCConfig, samples: Union[List[List], List[Dict]] = None): model = FreeVC(config) return model diff --git a/TTS/vc/modules/freevc/mel_processing.py b/TTS/vc/modules/freevc/mel_processing.py index 1955e758ac..a3e251891a 100644 --- a/TTS/vc/modules/freevc/mel_processing.py +++ b/TTS/vc/modules/freevc/mel_processing.py @@ -1,7 +1,11 @@ +import logging + import torch import torch.utils.data from librosa.filters import mel as librosa_mel_fn +logger = logging.getLogger(__name__) + MAX_WAV_VALUE = 32768.0 @@ -39,9 +43,9 @@ def spectral_de_normalize_torch(magnitudes): def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) + logger.info("Min value is: %.3f", torch.min(y)) if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) + logger.info("Max value is: %.3f", torch.max(y)) global hann_window dtype_device = str(y.dtype) + "_" + str(y.device) @@ -87,9 +91,9 @@ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) + logger.info("Min value is: %.3f", torch.min(y)) if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) + logger.info("Max value is: %.3f", torch.max(y)) global mel_basis, hann_window dtype_device = str(y.dtype) + "_" + str(y.device) diff --git a/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py b/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py index 7f811ac3ab..2636400b90 100644 --- a/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py +++ b/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py @@ -1,3 +1,4 @@ +import logging from time import perf_counter as timer from typing import List, Union @@ -17,9 +18,11 @@ sampling_rate, ) +logger = logging.getLogger(__name__) + class SpeakerEncoder(nn.Module): - def __init__(self, weights_fpath, device: Union[str, torch.device] = None, verbose=True): + def __init__(self, weights_fpath, device: Union[str, torch.device] = None): """ :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). If None, defaults to cuda if it is available on your machine, otherwise the model will @@ -50,9 +53,7 @@ def __init__(self, weights_fpath, device: Union[str, torch.device] = None, verbo self.load_state_dict(checkpoint["model_state"], strict=False) self.to(device) - - if verbose: - print("Loaded the voice encoder model on %s in %.2f seconds." % (device.type, timer() - start)) + logger.info("Loaded the voice encoder model on %s in %.2f seconds.", device.type, timer() - start) def forward(self, mels: torch.FloatTensor): """ diff --git a/TTS/vc/modules/freevc/wavlm/__init__.py b/TTS/vc/modules/freevc/wavlm/__init__.py index 6edada407b..0033d22c48 100644 --- a/TTS/vc/modules/freevc/wavlm/__init__.py +++ b/TTS/vc/modules/freevc/wavlm/__init__.py @@ -1,3 +1,4 @@ +import logging import os import urllib.request @@ -6,6 +7,8 @@ from TTS.utils.generic_utils import get_user_data_dir from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig +logger = logging.getLogger(__name__) + model_uri = "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/WavLM-Large.pt" @@ -20,7 +23,7 @@ def get_wavlm(device="cpu"): output_path = os.path.join(output_path, "WavLM-Large.pt") if not os.path.exists(output_path): - print(f" > Downloading WavLM model to {output_path} ...") + logger.info("Downloading WavLM model to %s ...", output_path) urllib.request.urlretrieve(model_uri, output_path) checkpoint = torch.load(output_path, map_location=torch.device(device)) diff --git a/TTS/vocoder/datasets/__init__.py b/TTS/vocoder/datasets/__init__.py index 871eb0d202..04462817a8 100644 --- a/TTS/vocoder/datasets/__init__.py +++ b/TTS/vocoder/datasets/__init__.py @@ -10,7 +10,7 @@ from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset -def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool) -> Dataset: +def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List) -> Dataset: if config.model.lower() in "gan": dataset = GANDataset( ap=ap, @@ -24,7 +24,6 @@ def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: return_segments=not is_eval, use_noise_augment=config.use_noise_augment, use_cache=config.use_cache, - verbose=verbose, ) dataset.shuffle_mapping() elif config.model.lower() == "wavegrad": @@ -39,7 +38,6 @@ def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: return_segments=True, use_noise_augment=False, use_cache=config.use_cache, - verbose=verbose, ) elif config.model.lower() == "wavernn": dataset = WaveRNNDataset( @@ -51,7 +49,6 @@ def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: mode=config.model_params.mode, mulaw=config.model_params.mulaw, is_training=not is_eval, - verbose=verbose, ) else: raise ValueError(f" [!] Dataset for model {config.model.lower()} cannot be found.") diff --git a/TTS/vocoder/datasets/gan_dataset.py b/TTS/vocoder/datasets/gan_dataset.py index 50c38c4deb..0806c0d496 100644 --- a/TTS/vocoder/datasets/gan_dataset.py +++ b/TTS/vocoder/datasets/gan_dataset.py @@ -28,7 +28,6 @@ def __init__( return_segments=True, use_noise_augment=False, use_cache=False, - verbose=False, ): super().__init__() self.ap = ap @@ -43,7 +42,6 @@ def __init__( self.return_segments = return_segments self.use_cache = use_cache self.use_noise_augment = use_noise_augment - self.verbose = verbose assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len." self.feat_frame_len = seq_len // hop_len + (2 * conv_pad) @@ -109,7 +107,6 @@ def load_item(self, idx): if self.compute_feat: # compute features from wav wavpath = self.item_list[idx] - # print(wavpath) if self.use_cache and self.cache[idx] is not None: audio, mel = self.cache[idx] diff --git a/TTS/vocoder/datasets/wavegrad_dataset.py b/TTS/vocoder/datasets/wavegrad_dataset.py index 305fe430e3..6f34bccb7c 100644 --- a/TTS/vocoder/datasets/wavegrad_dataset.py +++ b/TTS/vocoder/datasets/wavegrad_dataset.py @@ -28,7 +28,6 @@ def __init__( return_segments=True, use_noise_augment=False, use_cache=False, - verbose=False, ): super().__init__() self.ap = ap @@ -41,7 +40,6 @@ def __init__( self.return_segments = return_segments self.use_cache = use_cache self.use_noise_augment = use_noise_augment - self.verbose = verbose if return_segments: assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len." diff --git a/TTS/vocoder/datasets/wavernn_dataset.py b/TTS/vocoder/datasets/wavernn_dataset.py index a67c5b31a0..4c4f5c48df 100644 --- a/TTS/vocoder/datasets/wavernn_dataset.py +++ b/TTS/vocoder/datasets/wavernn_dataset.py @@ -1,9 +1,13 @@ +import logging + import numpy as np import torch from torch.utils.data import Dataset from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize +logger = logging.getLogger(__name__) + class WaveRNNDataset(Dataset): """ @@ -11,9 +15,7 @@ class WaveRNNDataset(Dataset): and converts them to acoustic features on the fly. """ - def __init__( - self, ap, items, seq_len, hop_len, pad, mode, mulaw, is_training=True, verbose=False, return_segments=True - ): + def __init__(self, ap, items, seq_len, hop_len, pad, mode, mulaw, is_training=True, return_segments=True): super().__init__() self.ap = ap self.compute_feat = not isinstance(items[0], (tuple, list)) @@ -25,7 +27,6 @@ def __init__( self.mode = mode self.mulaw = mulaw self.is_training = is_training - self.verbose = verbose self.return_segments = return_segments assert self.seq_len % self.hop_len == 0 @@ -60,7 +61,7 @@ def load_item(self, index): else: min_audio_len = audio.shape[0] + (2 * self.pad * self.hop_len) if audio.shape[0] < min_audio_len: - print(" [!] Instance is too short! : {}".format(wavpath)) + logger.warning("Instance is too short: %s", wavpath) audio = np.pad(audio, [0, min_audio_len - audio.shape[0] + self.hop_len]) mel = self.ap.melspectrogram(audio) @@ -80,7 +81,7 @@ def load_item(self, index): mel = np.load(feat_path.replace("/quant/", "/mel/")) if mel.shape[-1] < self.mel_len + 2 * self.pad: - print(" [!] Instance is too short! : {}".format(wavpath)) + logger.warning("Instance is too short: %s", wavpath) self.item_list[index] = self.item_list[index + 1] feat_path = self.item_list[index] mel = np.load(feat_path.replace("/quant/", "/mel/")) diff --git a/TTS/vocoder/models/__init__.py b/TTS/vocoder/models/__init__.py index 65901617b6..7a1716f16d 100644 --- a/TTS/vocoder/models/__init__.py +++ b/TTS/vocoder/models/__init__.py @@ -1,8 +1,11 @@ import importlib +import logging import re from coqpit import Coqpit +logger = logging.getLogger(__name__) + def to_camel(text): text = text.capitalize() @@ -27,13 +30,13 @@ def setup_model(config: Coqpit): MyModel = getattr(MyModel, to_camel(config.model)) except ModuleNotFoundError as e: raise ValueError(f"Model {config.model} not exist!") from e - print(" > Vocoder Model: {}".format(config.model)) + logger.info("Vocoder model: %s", config.model) return MyModel.init_from_config(config) def setup_generator(c): """TODO: use config object as arguments""" - print(" > Generator Model: {}".format(c.generator_model)) + logger.info("Generator model: %s", c.generator_model) MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower()) MyModel = getattr(MyModel, to_camel(c.generator_model)) # this is to preserve the Wavernn class name (instead of Wavernn) @@ -96,7 +99,7 @@ def setup_generator(c): def setup_discriminator(c): """TODO: use config objekt as arguments""" - print(" > Discriminator Model: {}".format(c.discriminator_model)) + logger.info("Discriminator model: %s", c.discriminator_model) if "parallel_wavegan" in c.discriminator_model: MyModel = importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator") else: diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py index 19c30e983e..9b6508d8ba 100644 --- a/TTS/vocoder/models/gan.py +++ b/TTS/vocoder/models/gan.py @@ -349,7 +349,6 @@ def get_data_loader( # pylint: disable=no-self-use, unused-argument return_segments=not is_eval, use_noise_augment=config.use_noise_augment, use_cache=config.use_cache, - verbose=verbose, ) dataset.shuffle_mapping() sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None @@ -369,6 +368,6 @@ def get_criterion(self): return [DiscriminatorLoss(self.config), GeneratorLoss(self.config)] @staticmethod - def init_from_config(config: Coqpit, verbose=True) -> "GAN": - ap = AudioProcessor.init_from_config(config, verbose=verbose) + def init_from_config(config: Coqpit) -> "GAN": + ap = AudioProcessor.init_from_config(config) return GAN(config, ap=ap) diff --git a/TTS/vocoder/models/hifigan_generator.py b/TTS/vocoder/models/hifigan_generator.py index 9247532259..b9561f6ff6 100644 --- a/TTS/vocoder/models/hifigan_generator.py +++ b/TTS/vocoder/models/hifigan_generator.py @@ -1,4 +1,6 @@ # adopted from https://github.com/jik876/hifi-gan/blob/master/models.py +import logging + import torch from torch import nn from torch.nn import Conv1d, ConvTranspose1d @@ -8,6 +10,8 @@ from TTS.utils.io import load_fsspec +logger = logging.getLogger(__name__) + LRELU_SLOPE = 0.1 @@ -282,7 +286,7 @@ def inference(self, c): return self.forward(c) def remove_weight_norm(self): - print("Removing weight norm...") + logger.info("Removing weight norm...") for l in self.ups: remove_parametrizations(l, "weight") for l in self.resblocks: diff --git a/TTS/vocoder/models/parallel_wavegan_discriminator.py b/TTS/vocoder/models/parallel_wavegan_discriminator.py index d02af75f05..211d45d91c 100644 --- a/TTS/vocoder/models/parallel_wavegan_discriminator.py +++ b/TTS/vocoder/models/parallel_wavegan_discriminator.py @@ -1,3 +1,4 @@ +import logging import math import torch @@ -6,6 +7,8 @@ from TTS.vocoder.layers.parallel_wavegan import ResidualBlock +logger = logging.getLogger(__name__) + class ParallelWaveganDiscriminator(nn.Module): """PWGAN discriminator as in https://arxiv.org/abs/1910.11480. @@ -76,7 +79,7 @@ def _apply_weight_norm(m): def remove_weight_norm(self): def _remove_weight_norm(m): try: - # print(f"Weight norm is removed from {m}.") + logger.info("Weight norm is removed from %s", m) remove_parametrizations(m, "weight") except ValueError: # this module didn't have weight norm return @@ -179,7 +182,7 @@ def _apply_weight_norm(m): def remove_weight_norm(self): def _remove_weight_norm(m): try: - print(f"Weight norm is removed from {m}.") + logger.info("Weight norm is removed from %s", m) remove_parametrizations(m, "weight") except ValueError: # this module didn't have weight norm return diff --git a/TTS/vocoder/models/parallel_wavegan_generator.py b/TTS/vocoder/models/parallel_wavegan_generator.py index 8338d94653..96684d2a0a 100644 --- a/TTS/vocoder/models/parallel_wavegan_generator.py +++ b/TTS/vocoder/models/parallel_wavegan_generator.py @@ -1,3 +1,4 @@ +import logging import math import numpy as np @@ -8,6 +9,8 @@ from TTS.vocoder.layers.parallel_wavegan import ResidualBlock from TTS.vocoder.layers.upsample import ConvUpsample +logger = logging.getLogger(__name__) + class ParallelWaveganGenerator(torch.nn.Module): """PWGAN generator as in https://arxiv.org/pdf/1910.11480.pdf. @@ -126,7 +129,7 @@ def inference(self, c): def remove_weight_norm(self): def _remove_weight_norm(m): try: - # print(f"Weight norm is removed from {m}.") + logger.info("Weight norm is removed from %s", m) remove_parametrizations(m, "weight") except ValueError: # this module didn't have weight norm return @@ -137,7 +140,7 @@ def apply_weight_norm(self): def _apply_weight_norm(m): if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)): torch.nn.utils.parametrizations.weight_norm(m) - # print(f"Weight norm is applied to {m}.") + logger.info("Weight norm is applied to %s", m) self.apply(_apply_weight_norm) diff --git a/TTS/vocoder/models/univnet_generator.py b/TTS/vocoder/models/univnet_generator.py index 5e66b70df8..72e57a9c39 100644 --- a/TTS/vocoder/models/univnet_generator.py +++ b/TTS/vocoder/models/univnet_generator.py @@ -1,3 +1,4 @@ +import logging from typing import List import numpy as np @@ -7,6 +8,8 @@ from TTS.vocoder.layers.lvc_block import LVCBlock +logger = logging.getLogger(__name__) + LRELU_SLOPE = 0.1 @@ -113,7 +116,7 @@ def remove_weight_norm(self): def _remove_weight_norm(m): try: - # print(f"Weight norm is removed from {m}.") + logger.info("Weight norm is removed from %s", m) parametrize.remove_parametrizations(m, "weight") except ValueError: # this module didn't have weight norm return @@ -126,7 +129,7 @@ def apply_weight_norm(self): def _apply_weight_norm(m): if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)): torch.nn.utils.parametrizations.weight_norm(m) - # print(f"Weight norm is applied to {m}.") + logger.info("Weight norm is applied to %s", m) self.apply(_apply_weight_norm) diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py index c1166e0914..70d9edb342 100644 --- a/TTS/vocoder/models/wavegrad.py +++ b/TTS/vocoder/models/wavegrad.py @@ -321,7 +321,6 @@ def get_data_loader(self, config: Coqpit, assets: Dict, is_eval: True, samples: return_segments=True, use_noise_augment=False, use_cache=config.use_cache, - verbose=verbose, ) sampler = DistributedSampler(dataset) if num_gpus > 1 else None loader = DataLoader( diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py index 7f74ba3ebf..62f6ee2d2d 100644 --- a/TTS/vocoder/models/wavernn.py +++ b/TTS/vocoder/models/wavernn.py @@ -623,7 +623,6 @@ def get_data_loader( # pylint: disable=no-self-use mode=config.model_args.mode, mulaw=config.model_args.mulaw, is_training=not is_eval, - verbose=verbose, ) sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None loader = DataLoader( diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py index 113240fd75..ac797d97f7 100644 --- a/TTS/vocoder/utils/generic_utils.py +++ b/TTS/vocoder/utils/generic_utils.py @@ -1,3 +1,4 @@ +import logging from typing import Dict import numpy as np @@ -7,6 +8,8 @@ from TTS.tts.utils.visual import plot_spectrogram from TTS.utils.audio import AudioProcessor +logger = logging.getLogger(__name__) + def interpolate_vocoder_input(scale_factor, spec): """Interpolate spectrogram by the scale factor. @@ -20,12 +23,12 @@ def interpolate_vocoder_input(scale_factor, spec): Returns: torch.tensor: interpolated spectrogram. """ - print(" > before interpolation :", spec.shape) + logger.info("Before interpolation: %s", spec.shape) spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0) # pylint: disable=not-callable spec = torch.nn.functional.interpolate( spec, scale_factor=scale_factor, recompute_scale_factor=True, mode="bilinear", align_corners=False ).squeeze(0) - print(" > after interpolation :", spec.shape) + logger.info("After interpolation: %s", spec.shape) return spec diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py index e76e29283e..17992773ad 100644 --- a/tests/tts_tests/test_vits.py +++ b/tests/tts_tests/test_vits.py @@ -212,7 +212,7 @@ def test_d_vector_forward(self): d_vector_file=[os.path.join(get_tests_data_path(), "dummy_speakers.json")], ) config = VitsConfig(model_args=args) - model = Vits.init_from_config(config, verbose=False).to(device) + model = Vits.init_from_config(config).to(device) model.train() input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size) d_vectors = torch.randn(batch_size, 256).to(device) @@ -357,7 +357,7 @@ def test_d_vector_inference(self): d_vector_file=[os.path.join(get_tests_data_path(), "dummy_speakers.json")], ) config = VitsConfig(model_args=args) - model = Vits.init_from_config(config, verbose=False).to(device) + model = Vits.init_from_config(config).to(device) model.eval() # batch size = 1 input_dummy = torch.randint(0, 24, (1, 128)).long().to(device) @@ -511,7 +511,7 @@ def test_train_step_upsampling_interpolation(self): def test_train_eval_log(self): batch_size = 2 config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10)) - model = Vits.init_from_config(config, verbose=False).to(device) + model = Vits.init_from_config(config).to(device) model.run_data_dep_init = False model.train() batch = self._create_batch(config, batch_size) @@ -530,7 +530,7 @@ def test_train_eval_log(self): def test_test_run(self): config = VitsConfig(model_args=VitsArgs(num_chars=32)) - model = Vits.init_from_config(config, verbose=False).to(device) + model = Vits.init_from_config(config).to(device) model.run_data_dep_init = False model.eval() test_figures, test_audios = model.test_run(None) @@ -540,7 +540,7 @@ def test_test_run(self): def test_load_checkpoint(self): chkp_path = os.path.join(get_tests_output_path(), "dummy_glow_tts_checkpoint.pth") config = VitsConfig(VitsArgs(num_chars=32)) - model = Vits.init_from_config(config, verbose=False).to(device) + model = Vits.init_from_config(config).to(device) chkp = {} chkp["model"] = model.state_dict() torch.save(chkp, chkp_path) @@ -551,20 +551,20 @@ def test_load_checkpoint(self): def test_get_criterion(self): config = VitsConfig(VitsArgs(num_chars=32)) - model = Vits.init_from_config(config, verbose=False).to(device) + model = Vits.init_from_config(config).to(device) criterion = model.get_criterion() self.assertTrue(criterion is not None) def test_init_from_config(self): config = VitsConfig(model_args=VitsArgs(num_chars=32)) - model = Vits.init_from_config(config, verbose=False).to(device) + model = Vits.init_from_config(config).to(device) config = VitsConfig(model_args=VitsArgs(num_chars=32, num_speakers=2)) - model = Vits.init_from_config(config, verbose=False).to(device) + model = Vits.init_from_config(config).to(device) self.assertTrue(not hasattr(model, "emb_g")) config = VitsConfig(model_args=VitsArgs(num_chars=32, num_speakers=2, use_speaker_embedding=True)) - model = Vits.init_from_config(config, verbose=False).to(device) + model = Vits.init_from_config(config).to(device) self.assertEqual(model.num_speakers, 2) self.assertTrue(hasattr(model, "emb_g")) @@ -576,7 +576,7 @@ def test_init_from_config(self): speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"), ) ) - model = Vits.init_from_config(config, verbose=False).to(device) + model = Vits.init_from_config(config).to(device) self.assertEqual(model.num_speakers, 10) self.assertTrue(hasattr(model, "emb_g")) @@ -588,7 +588,7 @@ def test_init_from_config(self): d_vector_file=[os.path.join(get_tests_data_path(), "dummy_speakers.json")], ) ) - model = Vits.init_from_config(config, verbose=False).to(device) + model = Vits.init_from_config(config).to(device) self.assertTrue(model.num_speakers == 1) self.assertTrue(not hasattr(model, "emb_g")) self.assertTrue(model.embedded_speaker_dim == config.d_vector_dim) diff --git a/tests/tts_tests2/test_glow_tts.py b/tests/tts_tests2/test_glow_tts.py index b93e701f19..3c7ac51556 100644 --- a/tests/tts_tests2/test_glow_tts.py +++ b/tests/tts_tests2/test_glow_tts.py @@ -132,7 +132,7 @@ def _test_forward_with_d_vector(self, batch_size): d_vector_dim=256, d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"), ) - model = GlowTTS.init_from_config(config, verbose=False).to(device) + model = GlowTTS.init_from_config(config).to(device) model.train() print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model))) # inference encoder and decoder with MAS @@ -158,7 +158,7 @@ def _test_forward_with_speaker_id(self, batch_size): use_speaker_embedding=True, num_speakers=24, ) - model = GlowTTS.init_from_config(config, verbose=False).to(device) + model = GlowTTS.init_from_config(config).to(device) model.train() print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model))) # inference encoder and decoder with MAS @@ -206,7 +206,7 @@ def _test_inference_with_d_vector(self, batch_size): d_vector_dim=256, d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"), ) - model = GlowTTS.init_from_config(config, verbose=False).to(device) + model = GlowTTS.init_from_config(config).to(device) model.eval() outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "d_vectors": d_vector}) self._assert_inference_outputs(outputs, input_dummy, mel_spec) @@ -224,7 +224,7 @@ def _test_inference_with_speaker_ids(self, batch_size): use_speaker_embedding=True, num_speakers=24, ) - model = GlowTTS.init_from_config(config, verbose=False).to(device) + model = GlowTTS.init_from_config(config).to(device) outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "speaker_ids": speaker_ids}) self._assert_inference_outputs(outputs, input_dummy, mel_spec) @@ -299,7 +299,7 @@ def test_train_eval_log(self): batch["d_vectors"] = None batch["speaker_ids"] = None config = GlowTTSConfig(num_chars=32) - model = GlowTTS.init_from_config(config, verbose=False).to(device) + model = GlowTTS.init_from_config(config).to(device) model.run_data_dep_init = False model.train() logger = TensorboardLogger( @@ -313,7 +313,7 @@ def test_train_eval_log(self): def test_test_run(self): config = GlowTTSConfig(num_chars=32) - model = GlowTTS.init_from_config(config, verbose=False).to(device) + model = GlowTTS.init_from_config(config).to(device) model.run_data_dep_init = False model.eval() test_figures, test_audios = model.test_run(None) @@ -323,7 +323,7 @@ def test_test_run(self): def test_load_checkpoint(self): chkp_path = os.path.join(get_tests_output_path(), "dummy_glow_tts_checkpoint.pth") config = GlowTTSConfig(num_chars=32) - model = GlowTTS.init_from_config(config, verbose=False).to(device) + model = GlowTTS.init_from_config(config).to(device) chkp = {} chkp["model"] = model.state_dict() torch.save(chkp, chkp_path) @@ -334,21 +334,21 @@ def test_load_checkpoint(self): def test_get_criterion(self): config = GlowTTSConfig(num_chars=32) - model = GlowTTS.init_from_config(config, verbose=False).to(device) + model = GlowTTS.init_from_config(config).to(device) criterion = model.get_criterion() self.assertTrue(criterion is not None) def test_init_from_config(self): config = GlowTTSConfig(num_chars=32) - model = GlowTTS.init_from_config(config, verbose=False).to(device) + model = GlowTTS.init_from_config(config).to(device) config = GlowTTSConfig(num_chars=32, num_speakers=2) - model = GlowTTS.init_from_config(config, verbose=False).to(device) + model = GlowTTS.init_from_config(config).to(device) self.assertTrue(model.num_speakers == 2) self.assertTrue(not hasattr(model, "emb_g")) config = GlowTTSConfig(num_chars=32, num_speakers=2, use_speaker_embedding=True) - model = GlowTTS.init_from_config(config, verbose=False).to(device) + model = GlowTTS.init_from_config(config).to(device) self.assertTrue(model.num_speakers == 2) self.assertTrue(hasattr(model, "emb_g")) @@ -358,7 +358,7 @@ def test_init_from_config(self): use_speaker_embedding=True, speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"), ) - model = GlowTTS.init_from_config(config, verbose=False).to(device) + model = GlowTTS.init_from_config(config).to(device) self.assertTrue(model.num_speakers == 10) self.assertTrue(hasattr(model, "emb_g")) @@ -368,7 +368,7 @@ def test_init_from_config(self): d_vector_dim=256, d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"), ) - model = GlowTTS.init_from_config(config, verbose=False).to(device) + model = GlowTTS.init_from_config(config).to(device) self.assertTrue(model.num_speakers == 1) self.assertTrue(not hasattr(model, "emb_g")) self.assertTrue(model.c_in_channels == config.d_vector_dim)