diff --git a/src/train.py b/src/train.py index 4adbcf442..4ee230f7f 100644 --- a/src/train.py +++ b/src/train.py @@ -91,7 +91,7 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: if cfg.get("test"): log.info("Starting testing!") ckpt_path = trainer.checkpoint_callback.best_model_path - if ckpt_path == "": + if ckpt_path: log.warning("Best ckpt not found! Using current weights for testing...") ckpt_path = None trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)