Skip to content

Commit

Permalink
fix(train): specify dataformats
Browse files Browse the repository at this point in the history
  • Loading branch information
34j committed Apr 8, 2023
1 parent a1b7552 commit 14c8333
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/so_vits_svc_fork/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from torch.cuda.amp import autocast
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter

import so_vits_svc_fork.f0
import so_vits_svc_fork.modules.commons as commons
Expand Down Expand Up @@ -196,22 +197,24 @@ def __init__(self, reset_optimizer: bool = False, **hparams: Any):
def configure_optimizers(self):
return [self.optim_g, self.optim_d], [self.scheduler_g, self.scheduler_d]

def log_image_dict(self, image_dict: dict[str, Any]) -> None:
def log_image_dict(
self, image_dict: dict[str, Any], dataformats: str = "HWC"
) -> None:
if not isinstance(self.logger, TensorBoardLogger):
warnings.warn("Image logging is only supported with TensorBoardLogger.")
return
writer = self.logger.experiment
writer: SummaryWriter = self.logger.experiment
for k, v in image_dict.items():
try:
writer.add_image(k, v, self.global_step)
writer.add_image(k, v, self.global_step, dataformats=dataformats)
except Exception as e:
warnings.warn(f"Failed to log image {k}: {e}")

def log_audio_dict(self, audio_dict: dict[str, Any]) -> None:
if not isinstance(self.logger, TensorBoardLogger):
warnings.warn("Audio logging is only supported with TensorBoardLogger.")
return
writer = self.logger.experiment
writer: SummaryWriter = self.logger.experiment
for k, v in audio_dict.items():
writer.add_audio(
k, v, self.global_step, sample_rate=self.hparams.data.sampling_rate
Expand Down

0 comments on commit 14c8333

Please sign in to comment.