From 14c83332d944fd636ea49173852ad0e504345365 Mon Sep 17 00:00:00 2001 From: 34j <55338215+34j@users.noreply.github.com> Date: Sun, 9 Apr 2023 00:46:39 +0900 Subject: [PATCH] fix(train): specify dataformats --- src/so_vits_svc_fork/train.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/so_vits_svc_fork/train.py b/src/so_vits_svc_fork/train.py index 1044ce4b..2209de6b 100644 --- a/src/so_vits_svc_fork/train.py +++ b/src/so_vits_svc_fork/train.py @@ -13,6 +13,7 @@ from torch.cuda.amp import autocast from torch.nn import functional as F from torch.utils.data import DataLoader +from torch.utils.tensorboard.writer import SummaryWriter import so_vits_svc_fork.f0 import so_vits_svc_fork.modules.commons as commons @@ -196,14 +197,16 @@ def __init__(self, reset_optimizer: bool = False, **hparams: Any): def configure_optimizers(self): return [self.optim_g, self.optim_d], [self.scheduler_g, self.scheduler_d] - def log_image_dict(self, image_dict: dict[str, Any]) -> None: + def log_image_dict( + self, image_dict: dict[str, Any], dataformats: str = "HWC" + ) -> None: if not isinstance(self.logger, TensorBoardLogger): warnings.warn("Image logging is only supported with TensorBoardLogger.") return - writer = self.logger.experiment + writer: SummaryWriter = self.logger.experiment for k, v in image_dict.items(): try: - writer.add_image(k, v, self.global_step) + writer.add_image(k, v, self.global_step, dataformats=dataformats) except Exception as e: warnings.warn(f"Failed to log image {k}: {e}") @@ -211,7 +214,7 @@ def log_audio_dict(self, audio_dict: dict[str, Any]) -> None: if not isinstance(self.logger, TensorBoardLogger): warnings.warn("Audio logging is only supported with TensorBoardLogger.") return - writer = self.logger.experiment + writer: SummaryWriter = self.logger.experiment for k, v in audio_dict.items(): writer.add_audio( k, v, self.global_step, sample_rate=self.hparams.data.sampling_rate