diff --git a/src/so_vits_svc_fork/logger.py b/src/so_vits_svc_fork/logger.py index 3d271db4..caa9110e 100644 --- a/src/so_vits_svc_fork/logger.py +++ b/src/so_vits_svc_fork/logger.py @@ -21,19 +21,34 @@ def init_logger() -> None: if LOGGER_INIT: return - IN_COLAB = os.getenv("COLAB_RELEASE_TAG") IS_TEST = "test" in Path.cwd().stem - + package_name = sys.modules[__name__].__package__ basicConfig( level=INFO, format="%(asctime)s %(message)s", datefmt="[%X]", handlers=[ - RichHandler() if not IN_COLAB else StreamHandler(), - FileHandler(f"{__name__.split('.')[0]}.log"), + StreamHandler() if is_notebook() else RichHandler(), + FileHandler(f"{package_name}.log"), ], ) if IS_TEST: - getLogger(sys.modules[__name__].__package__).setLevel(DEBUG) + getLogger(package_name).setLevel(DEBUG) captureWarnings(True) LOGGER_INIT = True + + +def is_notebook(): + try: + from IPython import get_ipython + + if "IPKernelApp" not in get_ipython().config: # pragma: no cover + raise ImportError("console") + return False + if "VSCODE_PID" in os.environ: # pragma: no cover + raise ImportError("vscode") + return False + except Exception: + return False + else: # pragma: no cover + return True diff --git a/src/so_vits_svc_fork/train.py b/src/so_vits_svc_fork/train.py index e3dce7ac..6bc98ae0 100644 --- a/src/so_vits_svc_fork/train.py +++ b/src/so_vits_svc_fork/train.py @@ -21,6 +21,7 @@ from . import utils from .dataset import TextAudioCollate, TextAudioDataset +from .logger import is_notebook from .modules.descriminators import MultiPeriodDiscriminator from .modules.losses import discriminator_loss, feature_loss, generator_loss, kl_loss from .modules.mel_processing import mel_spectrogram_torch @@ -85,7 +86,7 @@ def train( if hparams.train.get("bf16_run", False) else 32, strategy=strategy, - callbacks=[pl.callbacks.RichProgressBar()], + callbacks=[pl.callbacks.RichProgressBar()] if is_notebook() else None, ) model = VitsLightning(reset_optimizer=reset_optimizer, **hparams) trainer.fit(model, datamodule=datamodule)