Skip to content

Commit

Permalink
Add XTTS training unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
Edresson committed Oct 21, 2023
1 parent 1f92741 commit affaf11
Show file tree
Hide file tree
Showing 5 changed files with 12,858 additions and 17 deletions.
5 changes: 4 additions & 1 deletion TTS/tts/layers/xtts/trainer/gpt_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ def format_batch_on_device(self, batch):
dvae_wav = batch["wav"]
dvae_mel_spec = self.torch_mel_spectrogram_dvae(dvae_wav)
codes = self.dvae.get_codebook_indices(dvae_mel_spec)

batch["audio_codes"] = codes
# delete useless batch tensors
del batch["padded_text"]
Expand Down Expand Up @@ -454,7 +455,9 @@ def load_checkpoint(
target_options={"anon": True},
): # pylint: disable=unused-argument, disable=W0201, disable=W0102, redefined-builtin
"""Load the model checkpoint and setup for training or inference"""
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))["model"]

state, _ = self.xtts.get_compatible_checkpoint_state(checkpoint_path)

# load the model weights
self.xtts.load_state_dict(state, strict=strict)

Expand Down
37 changes: 21 additions & 16 deletions TTS/tts/models/xtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,7 @@ def inference(
expected_output_len = torch.tensor(
[gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
)

text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
gpt_latents = self.gpt(
text_tokens,
Expand Down Expand Up @@ -788,6 +789,25 @@ def eval(self): # pylint: disable=redefined-builtin
self.gpt.init_gpt_for_inference()
super().eval()

def get_compatible_checkpoint_state_dict(self, model_path):
    """Load a checkpoint and return a state dict compatible with this model.

    Two compatibility fixes are applied to the raw ``"model"`` state dict:

    * Checkpoints saved by the coqui Trainer prefix every key with
      ``"xtts."``; that prefix is stripped so the keys match this model's
      own submodule names.
    * Weights for decoders that the current config does not use
      (``diffusion_decoder`` / ``vocoder`` when a HiFi-GAN variant is
      enabled, and whichever of ``hifigan_decoder`` / ``ne_hifigan_decoder``
      is disabled) are removed so ``load_state_dict`` does not complain
      about unexpected keys.

    Args:
        model_path: Path or fsspec URL of the checkpoint file.

    Returns:
        dict: The filtered state dict, ready for ``self.load_state_dict``.

    NOTE(review): ``GPTTrainer.load_checkpoint`` in this commit calls
    ``self.xtts.get_compatible_checkpoint_state(checkpoint_path)`` and
    unpacks two values — that name/arity does not match this method and
    will fail at runtime; the call site should be aligned with this
    definition.
    """
    checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
    # Top-level submodules whose weights are unused under the current config.
    ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan or self.args.use_ne_hifigan else []
    ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
    ignore_keys += [] if self.args.use_ne_hifigan else ["ne_hifigan_decoder"]
    for key in list(checkpoint.keys()):
        # Checkpoints saved by the coqui Trainer wrap the model as "xtts.<key>".
        # Strip only the leading prefix: str.replace("xtts.", "") would also
        # mangle any later occurrence of "xtts." inside a key.
        if key.startswith("xtts."):
            new_key = key[len("xtts.") :]
            checkpoint[new_key] = checkpoint.pop(key)
            key = new_key

        # Drop weights belonging to decoders disabled by the current config.
        if key.split(".")[0] in ignore_keys:
            del checkpoint[key]

    return checkpoint

def load_checkpoint(
self,
config,
Expand Down Expand Up @@ -821,22 +841,7 @@ def load_checkpoint(

self.init_models()

checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan or self.args.use_ne_hifigan else []
ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
ignore_keys += [] if self.args.use_ne_hifigan else ["ne_hifigan_decoder"]
for key in list(checkpoint.keys()):
# check if it is from the coqui Trainer if so convert it
if key.startswith("xtts."):
coqui_trainer_checkpoint = True
new_key = key.replace("xtts.", "")
checkpoint[new_key] = checkpoint[key]
del checkpoint[key]
key = new_key

# remove unused keys
if key.split(".")[0] in ignore_keys:
del checkpoint[key]
checkpoint = self.get_compatible_checkpoint_state_dict(model_path)

# deal with v1 and v1.1. V1 has the init_gpt_for_inference keys, v1.1 do not
try:
Expand Down
1 change: 1 addition & 0 deletions recipes/ljspeech/xtts_v1/train_gpt_xtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
)
LANGUAGE = config_dataset.language


def main():
# init args and config
model_args = GPTArgs(
Expand Down
Loading

0 comments on commit affaf11

Please sign in to comment.