Skip to content

Commit

Permalink
config refactor #4 WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
erogol committed May 11, 2021
1 parent 97bd5f9 commit dc50f5f
Show file tree
Hide file tree
Showing 6 changed files with 247 additions and 285 deletions.
47 changes: 28 additions & 19 deletions TTS/bin/compute_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,41 @@
import numpy as np
from tqdm import tqdm

from TTS.utils.config_manager import ConfigManager
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_config


def main():
"""Run preprocessing process."""
parser = argparse.ArgumentParser(description="Compute mean and variance of spectrogtram features.")
parser.add_argument(
"--config_path", type=str, required=True, help="TTS config file path to define audio processin parameters."
)
parser.add_argument("--out_path", type=str, required=True, help="save path (directory and filename).")
CONFIG = ConfigManager()

parser = argparse.ArgumentParser(
description="Compute mean and variance of spectrogtram features.")
parser.add_argument("config_path", type=str,
help="TTS config file path to define audio processin parameters.")
parser.add_argument("out_path", type=str,
help="save path (directory and filename).")
parser.add_argument("--data_path", type=str, required=False,
help="folder including the target set of wavs overriding dataset config.")
parser = CONFIG.init_argparse(parser)
args = parser.parse_args()
CONFIG.parse_argparse(args)

# load config
CONFIG = load_config(args.config_path)
CONFIG.audio["signal_norm"] = False # do not apply earlier normalization
CONFIG.audio["stats_path"] = None # discard pre-defined stats
CONFIG.load_config(args.config_path)
CONFIG.audio_config.signal_norm = False # do not apply earlier normalization
CONFIG.audio_config.stats_path = None # discard pre-defined stats

# load audio processor
ap = AudioProcessor(**CONFIG.audio)
ap = AudioProcessor(**CONFIG.audio_config.to_dict())

# load the meta data of target dataset
if "data_path" in CONFIG.keys():
dataset_items = glob.glob(os.path.join(CONFIG.data_path, "**", "*.wav"), recursive=True)
if args.data_path:
dataset_items = glob.glob(os.path.join(args.data_path, '**', '*.wav'), recursive=True)
else:
dataset_items = load_meta_data(CONFIG.datasets)[0] # take only train data
dataset_items = load_meta_data(CONFIG.dataset_config)[0] # take only train data
print(f" > There are {len(dataset_items)} files.")

mel_sum = 0
Expand Down Expand Up @@ -73,14 +81,15 @@ def main():
print(f" > Avg lienar spec scale: {linear_scale.mean()}")

# set default config values for mean-var scaling
CONFIG.audio["stats_path"] = output_file_path
CONFIG.audio["signal_norm"] = True
CONFIG.audio_config.stats_path = output_file_path
CONFIG.audio_config.signal_norm = True
# remove redundant values
del CONFIG.audio["max_norm"]
del CONFIG.audio["min_level_db"]
del CONFIG.audio["symmetric_norm"]
del CONFIG.audio["clip_norm"]
stats["audio_config"] = CONFIG.audio
del CONFIG.audio_config.max_norm
del CONFIG.audio_config.min_level_db
del CONFIG.audio_config.symmetric_norm
del CONFIG.audio_config.clip_norm
breakpoint()
stats['audio_config'] = CONFIG.audio_config.to_dict()
np.save(output_file_path, stats, allow_pickle=True)
print(f" > stats saved to {output_file_path}")

Expand Down
14 changes: 9 additions & 5 deletions TTS/bin/train_tacotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,9 @@
import numpy as np
import torch
from torch.utils.data import DataLoader

from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.layers.losses import TacotronLoss
from TTS.tts.configs.tacotron_config import TacotronConfig
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint
from TTS.tts.utils.measures import alignment_diagonal_score
Expand All @@ -24,8 +22,11 @@
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.arguments import parse_arguments, process_args
from TTS.utils.audio import AudioProcessor
from TTS.utils.distribute import DistributedSampler, apply_gradient_allreduce, init_distributed, reduce_tensor
from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.config_manager import ConfigManager
from TTS.utils.distribute import (DistributedSampler, apply_gradient_allreduce,
init_distributed, reduce_tensor)
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
remove_experiment_folder, set_init_dict)
from TTS.utils.radam import RAdam
from TTS.utils.training import (
NoamLR,
Expand Down Expand Up @@ -739,7 +740,10 @@ def main(args): # pylint: disable=redefined-outer-name

if __name__ == "__main__":
args = parse_arguments(sys.argv)
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="tts")
c = TacotronConfig()
args = c.init_argparse(args)
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
args, c, model_type='tacotron')

try:
main(args)
Expand Down
Loading

0 comments on commit dc50f5f

Please sign in to comment.