Commit 34e8942 · 1 parent 581cf50
OGUN Sewade Olaolu committed Mar 15, 2024
Showing 4 changed files with 193 additions and 3 deletions.
@@ -0,0 +1,171 @@
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from trainer import Trainer, TrainerArgs

from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
from TTS.utils.manage import ModelManager
# Logging parameters
RUN_NAME = "GPT_XTTS_v2.0_AfroTTS_FT"
PROJECT_NAME = "XTTS_trainer"
DASHBOARD_LOGGER = "tensorboard"
LOGGER_URI = None

# Set the path where the checkpoints will be saved. Default: ./run/training/
OUT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "run", "training")
# Training parameters
OPTIMIZER_WD_ONLY_ON_WEIGHTS = True  # for multi-GPU training set this to False
START_WITH_EVAL = False  # if True, training will start with an evaluation pass
BATCH_SIZE = 2  # set the batch size here
GRAD_ACUMM_STEPS = 126  # set the gradient accumulation steps here
# Note: we recommend that BATCH_SIZE * GRAD_ACUMM_STEPS be at least 252 for more efficient training.
# You can increase/decrease BATCH_SIZE, but then adjust GRAD_ACUMM_STEPS accordingly.
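# Worked example: with the values above, the effective batch size per optimizer
# step is BATCH_SIZE * GRAD_ACUMM_STEPS = 2 * 126 = 252, which meets the
# recommended minimum of 252.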
afrotts_dir = "AfriSpeech-TTS-D"  # add path to afrotts data here
# Define the dataset to be used for fine-tuning.
config_dataset = BaseDatasetConfig(
    formatter="afrotts",
    dataset_name="afrotts",
    path=afrotts_dir,
    meta_file_train=os.path.join(afrotts_dir, "data/afritts-train-clean-upsamp.csv"),
    meta_file_val=os.path.join(afrotts_dir, "data/afritts-dev-clean.csv"),
    language="en",
)
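# The "afrotts" formatter must be registered in TTS.tts.datasets.formatters,
# presumably by one of the other files changed in this commit. As a minimal
# sketch, a formatter takes the dataset root and a metadata file and returns a
# list of sample dicts; the CSV column names below are hypothetical:
#
# def afrotts(root_path, meta_file, **kwargs):
#     import csv
#     items = []
#     with open(meta_file, newline="", encoding="utf-8") as f:
#         for row in csv.DictReader(f):
#             items.append({
#                 "text": row["transcript"],
#                 "audio_file": os.path.join(root_path, row["audio_path"]),
#                 "speaker_name": row.get("speaker_id", "afrotts"),
#                 "root_path": root_path,
#             })
#     return items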
# Add here the configs of the datasets
DATASETS_CONFIG_LIST = [config_dataset]
# Define the path where XTTS v2.0.1 files will be downloaded
CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/")
os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)
# Set the path to the downloaded files
DVAE_CHECKPOINT = "coqui-ai-TTS/recipes/ljspeech/xtts_v2/run/training/XTTS_v2.0_original_model_files/dvae.pth"
MEL_NORM_FILE = "coqui-ai-TTS/recipes/ljspeech/xtts_v2/run/training/XTTS_v2.0_original_model_files/mel_stats.pth"

# Download URLs for the pretrained DVAE files (as used in the upstream XTTS v2 recipe)
DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth"
MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth"

# download DVAE files if needed
if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
    print(" > Downloading DVAE files!")
    ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
# XTTS transfer learning parameters: you need to provide the paths to the XTTS model checkpoint that you want to fine-tune.
TOKENIZER_FILE = "coqui-ai-TTS/recipes/ljspeech/xtts_v2/run/training/XTTS_v2.0_original_model_files/vocab.json"
XTTS_CHECKPOINT = "coqui-ai-TTS/recipes/ljspeech/xtts_v2/run/training/GPT_XTTS_v2.0_AfroTTS_FT-March-06-2024_06+36AM-581cf506/checkpoint_135000.pth"

# Download URLs for the base XTTS v2.0 files (as used in the upstream XTTS v2 recipe)
TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json"
XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth"

# download XTTS v2.0 files if needed
if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT):
    print(" > Downloading XTTS v2.0 files!")
    ModelManager._download_model_files(
        [TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
    )
# Speaker reference to be used in the test sentences generated during training
SPEAKER_REFERENCE = [
    "/AfriSpeech-TTS/train/defc5e03-926c-4e0b-a639-c821e5e7db89/14f64f13c57f9a64a2a1521253934a0b_KYA8MaKS.wav"
]
LANGUAGE = config_dataset.language
def main():
    # init args and config
    model_args = GPTArgs(
        max_conditioning_length=132300,  # 6 secs at 22050 Hz
        min_conditioning_length=66150,  # 3 secs at 22050 Hz
        debug_loading_failures=False,
        max_wav_length=255995,  # ~11.6 seconds (use 661500 for ~30 seconds)
        max_text_length=300,
        mel_norm_file=MEL_NORM_FILE,
        dvae_checkpoint=DVAE_CHECKPOINT,
        xtts_checkpoint=XTTS_CHECKPOINT,  # checkpoint path of the model that you want to fine-tune
        tokenizer_file=TOKENIZER_FILE,
        gpt_num_audio_tokens=1026,
        gpt_start_audio_token=1024,
        gpt_stop_audio_token=1025,
        gpt_use_masking_gt_prompt_approach=True,
        gpt_use_perceiver_resampler=True,
    )
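    # Audio-token bookkeeping (inferred from the values above): the DVAE codebook
    # contributes ids 0-1023, and ids 1024/1025 serve as the start/stop tokens,
    # giving gpt_num_audio_tokens = 1026 in total.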
    # define audio config
    audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
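    # Background note: the GPT encoder and DVAE consume 22.05 kHz audio, while
    # XTTS v2's HiFi-GAN decoder renders the final waveform at 24 kHz, hence the
    # two different rates above.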
    # training parameters config
    config = GPTTrainerConfig(
        output_path=OUT_PATH,
        model_args=model_args,
        run_name=RUN_NAME,
        project_name=PROJECT_NAME,
        run_description="""
            GPT XTTS training
            """,
        dashboard_logger=DASHBOARD_LOGGER,
        logger_uri=LOGGER_URI,
        audio=audio_config,
        batch_size=BATCH_SIZE,
        batch_group_size=64,
        eval_batch_size=BATCH_SIZE,
        num_loader_workers=8,
        eval_split_max_size=256,
        print_step=1000,
        plot_step=1000,
        log_model_step=1000,
        save_step=1000,
        save_n_checkpoints=3,
        save_checkpoints=True,
        # target_loss="loss",
        print_eval=True,
        # Optimizer values like Tortoise; PyTorch implementation with modifications so that WD is not applied to non-weight parameters.
        optimizer="AdamW",
        optimizer_wd_only_on_weights=OPTIMIZER_WD_ONLY_ON_WEIGHTS,
        optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
        lr=5e-06,  # learning rate
        lr_scheduler="MultiStepLR",
        # the milestones were adjusted accordingly for the new step scheme
        lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
        test_sentences=[
            {
                "text": "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
                "speaker_wav": SPEAKER_REFERENCE,
                "language": LANGUAGE,
            },
            {
                "text": "This cake is great. It's so delicious and moist.",
                "speaker_wav": SPEAKER_REFERENCE,
                "language": LANGUAGE,
            },
        ],
    )
    # init the model from config
    model = GPTTrainer.init_from_config(config)

    # load training samples
    train_samples, eval_samples = load_tts_samples(
        DATASETS_CONFIG_LIST,
        eval_split=True,
        eval_split_max_size=config.eval_split_max_size,
        eval_split_size=config.eval_split_size,
    )
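    # Because config_dataset defines meta_file_val, the eval samples come from
    # that CSV rather than from an automatic split of the training set.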
    # init the trainer and 🚀
    trainer = Trainer(
        TrainerArgs(
            restore_path=None,  # the XTTS checkpoint is restored via the xtts_checkpoint key, so there is no need to restore it with the Trainer's restore_path parameter
            skip_train_epoch=False,
            start_with_eval=START_WITH_EVAL,
            grad_accum_steps=GRAD_ACUMM_STEPS,
        ),
        config,
        output_path=OUT_PATH,
        model=model,
        train_samples=train_samples,
        eval_samples=eval_samples,
    )
    trainer.fit()
if __name__ == "__main__":
    main()
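To launch fine-tuning, run the script directly (the file name here is assumed, since it is not shown on this page):

    python train_gpt_xtts_afrotts.py

For multi-GPU training, set OPTIMIZER_WD_ONLY_ON_WEIGHTS = False and relax the hard-coded CUDA_VISIBLE_DEVICES at the top of the script, then use the Coqui trainer's distributed launcher, e.g.:

    python -m trainer.distribute --script train_gpt_xtts_afrotts.py --gpus "0,1"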