diff --git a/pypots/base.py b/pypots/base.py index 6fb7d1a6..ac3287c5 100644 --- a/pypots/base.py +++ b/pypots/base.py @@ -450,6 +450,8 @@ class BaseNNModel(BaseModel): The criteria to judge whether the model's performance is the best so far. Usually the lower, the better. + best_epoch : int, default = -1, + The epoch number when the best loss is got. Notes ----- @@ -494,8 +496,8 @@ def __init__( self.model = None self.optimizer = None self.best_model_dict = None - # WDU: may enable users to customize the criteria in the future self.best_loss = float("inf") + self.best_epoch = -1 def _print_model_size(self) -> None: """Print the number of trainable parameters in the initialized NN model.""" diff --git a/pypots/classification/base.py b/pypots/classification/base.py index 19f73d2b..50bd5afd 100644 --- a/pypots/classification/base.py +++ b/pypots/classification/base.py @@ -337,6 +337,7 @@ def _train_model( ) if mean_loss < self.best_loss: + self.best_epoch = epoch self.best_loss = mean_loss self.best_model_dict = self.model.state_dict() self.patience = self.original_patience @@ -376,7 +377,9 @@ def _train_model( if np.isnan(self.best_loss): raise ValueError("Something is wrong. best_loss is Nan after training.") - logger.info("Finished training.") + logger.info( + f"Finished training. The best model is from epoch#{self.best_epoch}." + ) @abstractmethod def fit( diff --git a/pypots/clustering/base.py b/pypots/clustering/base.py index b0bb3336..47f70a18 100644 --- a/pypots/clustering/base.py +++ b/pypots/clustering/base.py @@ -336,6 +336,7 @@ def _train_model( ) if mean_loss < self.best_loss: + self.best_epoch = epoch self.best_loss = mean_loss self.best_model_dict = self.model.state_dict() self.patience = self.original_patience @@ -369,7 +370,9 @@ def _train_model( if np.isnan(self.best_loss): raise ValueError("Something is wrong. best_loss is Nan after training.") - logger.info("Finished training.") + logger.info( + f"Finished training. The best model is from epoch#{self.best_epoch}." + ) @abstractmethod def fit( diff --git a/pypots/clustering/crli/model.py b/pypots/clustering/crli/model.py index 2eff7647..e6b8c23f 100644 --- a/pypots/clustering/crli/model.py +++ b/pypots/clustering/crli/model.py @@ -296,6 +296,7 @@ def _train_model( ) if mean_loss < self.best_loss: + self.best_epoch = epoch self.best_loss = mean_loss self.best_model_dict = self.model.state_dict() self.patience = self.original_patience @@ -335,7 +336,9 @@ def _train_model( if np.isnan(self.best_loss): raise ValueError("Something is wrong. best_loss is Nan after training.") - logger.info("Finished training.") + logger.info( + f"Finished training. The best model is from epoch#{self.best_epoch}." + ) def fit( self, diff --git a/pypots/clustering/vader/model.py b/pypots/clustering/vader/model.py index eafbfddb..1d3eaa73 100644 --- a/pypots/clustering/vader/model.py +++ b/pypots/clustering/vader/model.py @@ -309,6 +309,7 @@ def _train_model( ) if mean_loss < self.best_loss: + self.best_epoch = epoch self.best_loss = mean_loss self.best_model_dict = self.model.state_dict() self.patience = self.original_patience @@ -348,7 +349,9 @@ def _train_model( if np.isnan(self.best_loss): raise ValueError("Something is wrong. best_loss is Nan after training.") - logger.info("Finished training.") + logger.info( + f"Finished training. The best model is from epoch#{self.best_epoch}." + ) def fit( self, diff --git a/pypots/forecasting/base.py b/pypots/forecasting/base.py index 1ece900c..2cdf641d 100644 --- a/pypots/forecasting/base.py +++ b/pypots/forecasting/base.py @@ -331,6 +331,7 @@ def _train_model( ) if mean_loss < self.best_loss: + self.best_epoch = epoch self.best_loss = mean_loss self.best_model_dict = self.model.state_dict() self.patience = self.original_patience @@ -370,7 +371,9 @@ def _train_model( if np.isnan(self.best_loss): raise ValueError("Something is wrong. best_loss is Nan after training.") - logger.info("Finished training.") + logger.info( + f"Finished training. The best model is from epoch#{self.best_epoch}." + ) @abstractmethod def fit( diff --git a/pypots/imputation/base.py b/pypots/imputation/base.py index 30a87a42..284d1af2 100644 --- a/pypots/imputation/base.py +++ b/pypots/imputation/base.py @@ -334,6 +334,7 @@ def _train_model( ) if mean_loss < self.best_loss: + self.best_epoch = epoch self.best_loss = mean_loss self.best_model_dict = self.model.state_dict() self.patience = self.original_patience @@ -373,7 +374,9 @@ def _train_model( if np.isnan(self.best_loss): raise ValueError("Something is wrong. best_loss is Nan after training.") - logger.info("Finished training.") + logger.info( + f"Finished training. The best model is from epoch#{self.best_epoch}." + ) @abstractmethod def fit( diff --git a/pypots/imputation/csdi/model.py b/pypots/imputation/csdi/model.py index 62911931..b30e8de9 100644 --- a/pypots/imputation/csdi/model.py +++ b/pypots/imputation/csdi/model.py @@ -283,6 +283,7 @@ def _train_model( ) if mean_loss < self.best_loss: + self.best_epoch = epoch self.best_loss = mean_loss self.best_model_dict = self.model.state_dict() self.patience = self.original_patience @@ -322,7 +323,9 @@ def _train_model( if np.isnan(self.best_loss): raise ValueError("Something is wrong. best_loss is Nan after training.") - logger.info("Finished training.") + logger.info( + f"Finished training. The best model is from epoch#{self.best_epoch}." + ) def fit( self, diff --git a/pypots/imputation/gpvae/model.py b/pypots/imputation/gpvae/model.py index e1bf5120..9d33d275 100644 --- a/pypots/imputation/gpvae/model.py +++ b/pypots/imputation/gpvae/model.py @@ -313,6 +313,7 @@ def _train_model( ) if mean_loss < self.best_loss: + self.best_epoch = epoch self.best_loss = mean_loss self.best_model_dict = self.model.state_dict() self.patience = self.original_patience @@ -352,7 +353,9 @@ def _train_model( if np.isnan(self.best_loss): raise ValueError("Something is wrong. best_loss is Nan after training.") - logger.info("Finished training.") + logger.info( + f"Finished training. The best model is from epoch#{self.best_epoch}." + ) def fit( self, diff --git a/pypots/imputation/usgan/model.py b/pypots/imputation/usgan/model.py index cb90f092..89a674f3 100644 --- a/pypots/imputation/usgan/model.py +++ b/pypots/imputation/usgan/model.py @@ -330,6 +330,7 @@ def _train_model( ) if mean_loss < self.best_loss: + self.best_epoch = epoch self.best_loss = mean_loss self.best_model_dict = self.model.state_dict() self.patience = self.original_patience @@ -369,7 +370,9 @@ def _train_model( if np.isnan(self.best_loss): raise ValueError("Something is wrong. best_loss is Nan after training.") - logger.info("Finished training.") + logger.info( + f"Finished training. The best model is from epoch#{self.best_epoch}." + ) def fit( self,