From 963e48244cb2c7a9ee180cc4a2b10ce6770b7f3d Mon Sep 17 00:00:00 2001
From: anaprietonem
Date: Mon, 9 Dec 2024 10:14:01 +0000
Subject: [PATCH] cleaning

---
 .../training/diagnostics/callbacks/checkpoint.py | 11 +----------
 1 file changed, 1 insertion(+), 10 deletions(-)

diff --git a/src/anemoi/training/diagnostics/callbacks/checkpoint.py b/src/anemoi/training/diagnostics/callbacks/checkpoint.py
index 6ea85d78..82b5c1b4 100644
--- a/src/anemoi/training/diagnostics/callbacks/checkpoint.py
+++ b/src/anemoi/training/diagnostics/callbacks/checkpoint.py
@@ -77,20 +77,11 @@ def model_metadata(self, model: torch.nn.Module) -> dict:
 
         return self._model_metadata
 
-    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        """Save a checkpoint at the end of the validation stage."""
-        del pl_module
-        if not self._should_skip_saving_checkpoint(trainer) and not self._should_save_on_train_epoch_end(trainer):
-            monitor_candidates = self._monitor_candidates(trainer)
-            if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
-                self._save_topk_checkpoint(trainer, monitor_candidates)
-            self._save_last_checkpoint(trainer, monitor_candidates)
-
     def on_fit_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
         del pl_module
         if not self._should_skip_saving_checkpoint(trainer) and not self._should_save_on_train_epoch_end(trainer):
             monitor_candidates = self._monitor_candidates(trainer)
-            # Need to correct the checkpoint epoch to the last epoch
+            # PTL advances one epoch at the end of training, so we need to correct the checkpoint epoch to the last epoch
            monitor_candidates["epoch"] = trainer.current_epoch - 1
             self._save_topk_checkpoint(trainer, monitor_candidates)
             self._save_last_checkpoint(trainer, monitor_candidates)
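
Note: the off-by-one that the corrected comment describes can be reproduced with a minimal, self-contained Lightning run. This is a sketch, not part of the patch; TinyModule and the toy dataset are hypothetical demo names, and the behavior shown is that of recent PyTorch Lightning versions, where trainer.current_epoch equals the number of completed epochs after fit() rather than the index of the last epoch. That is why on_fit_end above subtracts one before saving.

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

class TinyModule(pl.LightningModule):  # hypothetical demo module, assumption
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

if __name__ == "__main__":
    # Toy data: 8 random samples, batched in fours
    data = DataLoader(TensorDataset(torch.randn(8, 4), torch.randn(8, 1)), batch_size=4)
    trainer = pl.Trainer(max_epochs=3, logger=False, enable_checkpointing=False)
    trainer.fit(TinyModule(), data)
    # Epoch indices seen during training were 0, 1, 2, but at fit end:
    print(trainer.current_epoch)  # -> 3, hence the `trainer.current_epoch - 1` correction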