diff --git a/src/so_vits_svc_fork/train.py b/src/so_vits_svc_fork/train.py index 1044ce4b..2209de6b 100644 --- a/src/so_vits_svc_fork/train.py +++ b/src/so_vits_svc_fork/train.py @@ -13,6 +13,7 @@ from torch.cuda.amp import autocast from torch.nn import functional as F from torch.utils.data import DataLoader +from torch.utils.tensorboard.writer import SummaryWriter import so_vits_svc_fork.f0 import so_vits_svc_fork.modules.commons as commons @@ -196,14 +197,16 @@ def __init__(self, reset_optimizer: bool = False, **hparams: Any): def configure_optimizers(self): return [self.optim_g, self.optim_d], [self.scheduler_g, self.scheduler_d] - def log_image_dict(self, image_dict: dict[str, Any]) -> None: + def log_image_dict( + self, image_dict: dict[str, Any], dataformats: str = "HWC" + ) -> None: if not isinstance(self.logger, TensorBoardLogger): warnings.warn("Image logging is only supported with TensorBoardLogger.") return - writer = self.logger.experiment + writer: SummaryWriter = self.logger.experiment for k, v in image_dict.items(): try: - writer.add_image(k, v, self.global_step) + writer.add_image(k, v, self.global_step, dataformats=dataformats) except Exception as e: warnings.warn(f"Failed to log image {k}: {e}") @@ -211,7 +214,7 @@ def log_audio_dict(self, audio_dict: dict[str, Any]) -> None: if not isinstance(self.logger, TensorBoardLogger): warnings.warn("Audio logging is only supported with TensorBoardLogger.") return - writer = self.logger.experiment + writer: SummaryWriter = self.logger.experiment for k, v in audio_dict.items(): writer.add_audio( k, v, self.global_step, sample_rate=self.hparams.data.sampling_rate