diff --git a/src/so_vits_svc_fork/train.py b/src/so_vits_svc_fork/train.py
index e7d9e772..c298c0e8 100644
--- a/src/so_vits_svc_fork/train.py
+++ b/src/so_vits_svc_fork/train.py
@@ -8,7 +8,7 @@

 import lightning.pytorch as pl
 import torch
-from lightning.pytorch.accelerators import TPUAccelerator
+from lightning.pytorch.accelerators import MPSAccelerator, TPUAccelerator
 from lightning.pytorch.loggers import TensorBoardLogger
 from lightning.pytorch.tuner import Tuner
 from torch.cuda.amp import autocast
@@ -175,10 +175,10 @@ def on_train_start(self) -> None:
         global_step = total_batch_idx * self.optimizers_count
         self.set_global_step(global_step)

-        # check if using tpu
-        if isinstance(self.trainer.accelerator, TPUAccelerator):
+        # check if using tpu or mps
+        if isinstance(self.trainer.accelerator, (TPUAccelerator, MPSAccelerator)):
             # patch torch.stft to use cpu
-            LOG.warning("Using TPU. Patching torch.stft to use cpu.")
+            LOG.warning("Using TPU/MPS. Patching torch.stft to use cpu.")

             def stft(
                 input: torch.Tensor,
diff --git a/src/so_vits_svc_fork/utils.py b/src/so_vits_svc_fork/utils.py
index 441bec41..5f76eae1 100644
--- a/src/so_vits_svc_fork/utils.py
+++ b/src/so_vits_svc_fork/utils.py
@@ -13,6 +13,7 @@
 import numpy as np
 import requests
 import torch
+import torch.backends.mps
 from cm_time import timer
 from fairseq import checkpoint_utils
 from fairseq.models.hubert.hubert import HubertModel
@@ -28,6 +29,8 @@
 def get_optimal_device(index: int = 0) -> torch.device:
     if torch.cuda.is_available():
         return torch.device(f"cuda:{index % torch.cuda.device_count()}")
+    elif torch.backends.mps.is_available():
+        return torch.device("mps")
     else:
         try:
             import torch_xla.core.xla_model as xm  # noqa
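Note on the train.py hunk: the body of the patched stft wrapper is truncated in this diff. Below is a minimal sketch of the CPU-fallback idea it describes, not the project's exact code; the wrapper name, the window handling, and moving the result back to the original device are assumptions for illustration.

import torch

_original_stft = torch.stft  # keep a reference to the real implementation


def stft(input: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    # Illustrative only: run the transform on CPU, then move the result
    # back to the tensor's original device (e.g. mps or xla), since those
    # backends lack the FFT kernels torch.stft needs.
    device = input.device
    window = kwargs.get("window")
    if isinstance(window, torch.Tensor):
        kwargs = {**kwargs, "window": window.cpu()}
    return _original_stft(input.cpu(), *args, **kwargs).to(device)


torch.stft = stft  # monkeypatch, as the warning message in the diff describes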