Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/34j/so-vits-svc-fork into f…
Browse files Browse the repository at this point in the history
…ix/segment-size
  • Loading branch information
34j committed Apr 16, 2023
2 parents 04972ea + da928aa commit da02f4b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/so_vits_svc_fork/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import lightning.pytorch as pl
import torch
from lightning.pytorch.accelerators import TPUAccelerator
from lightning.pytorch.accelerators import MPSAccelerator, TPUAccelerator
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.tuner import Tuner
from torch.cuda.amp import autocast
Expand Down Expand Up @@ -175,10 +175,10 @@ def on_train_start(self) -> None:
global_step = total_batch_idx * self.optimizers_count
self.set_global_step(global_step)

# check if using tpu
if isinstance(self.trainer.accelerator, TPUAccelerator):
# check if using tpu or mps
if isinstance(self.trainer.accelerator, (TPUAccelerator, MPSAccelerator)):
# patch torch.stft to use cpu
LOG.warning("Using TPU. Patching torch.stft to use cpu.")
LOG.warning("Using TPU/MPS. Patching torch.stft to use cpu.")

def stft(
input: torch.Tensor,
Expand Down
3 changes: 3 additions & 0 deletions src/so_vits_svc_fork/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import numpy as np
import requests
import torch
import torch.backends.mps
from cm_time import timer
from fairseq import checkpoint_utils
from fairseq.models.hubert.hubert import HubertModel
Expand All @@ -28,6 +29,8 @@
def get_optimal_device(index: int = 0) -> torch.device:
if torch.cuda.is_available():
return torch.device(f"cuda:{index % torch.cuda.device_count()}")
elif torch.backends.mps.is_available():
return torch.device("mps")
else:
try:
import torch_xla.core.xla_model as xm # noqa
Expand Down

0 comments on commit da02f4b

Please sign in to comment.