Skip to content

Commit

Permalink
Update recipe
Browse files Browse the repository at this point in the history
  • Loading branch information
Edresson committed Oct 21, 2023
1 parent affaf11 commit 1cf1eba
Showing 1 changed file with 2 additions and 7 deletions.
9 changes: 2 additions & 7 deletions TTS/tts/layers/xtts/trainer/gpt_trainer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import os
import sys
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Dict, List, Tuple, Union

import torch
import torch.nn as nn
Expand All @@ -12,13 +10,10 @@
from trainer.torch import DistributedSampler
from trainer.trainer_utils import get_optimizer, get_scheduler

from TTS.tts.configs.tortoise_config import TortoiseConfig
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
from TTS.tts.layers.xtts.dvae import DiscreteVAE
from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
from TTS.tts.models.base_tts import BaseTTS
Expand Down Expand Up @@ -456,7 +451,7 @@ def load_checkpoint(
): # pylint: disable=unused-argument, disable=W0201, disable=W0102, redefined-builtin
"""Load the model checkpoint and setup for training or inference"""

state, _ = self.xtts.get_compatible_checkpoint_state(checkpoint_path)
state = self.xtts.get_compatible_checkpoint_state_dict(checkpoint_path)

# load the model weights
self.xtts.load_state_dict(state, strict=strict)
Expand Down

0 comments on commit 1cf1eba

Please sign in to comment.