From 473a3c50f6d06558dfce79d36d03861c008b2549 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Sat, 24 Oct 2020 11:16:28 +0100
Subject: [PATCH] Removed model checkpoint code, added barrier to trainer to
 ensure we synchronize and wait for all processes to finish before completing
 training

---
 pytorch_lightning/accelerators/ddp_accelerator.py |  1 -
 pytorch_lightning/callbacks/model_checkpoint.py   | 12 ------------
 pytorch_lightning/trainer/trainer.py              |  1 +
 3 files changed, 1 insertion(+), 13 deletions(-)

diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py
index 79bc89553c6ea..e9566dc930a67 100644
--- a/pytorch_lightning/accelerators/ddp_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_accelerator.py
@@ -145,7 +145,6 @@ def train(self):
         results = self.ddp_train(process_idx=self.task_idx, model=model)
         if 'WORLD_SIZE' in os.environ:
             del os.environ['WORLD_SIZE']
-        self.barrier('ddp_end_train')
         return results
 
     def training_step(self, args):
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 766df5ca3b6e5..6c6a1741c31c5 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -29,7 +29,6 @@
 import numpy as np
 import torch
-import torch.distributed as torch_distrib
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn, rank_zero_info
@@ -196,9 +195,6 @@ def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]):
         self.best_model_score = checkpointed_state["best_model_score"]
         self.best_model_path = checkpointed_state["best_model_path"]
 
-    def on_fit_end(self, trainer, pl_module) -> None:
-        self._sync_best_model_across_procs(trainer)
-
     def save_checkpoint(self, trainer, pl_module):
         """
         Performs the main logic around saving a checkpoint.
@@ -238,14 +234,6 @@ def save_checkpoint(self, trainer, pl_module):
         # Mode 2: save the last checkpoint
         self._save_last_checkpoint(trainer, pl_module, epoch, monitor_candidates, filepath)
 
-    def _sync_best_model_across_procs(self, trainer) -> None:
-        if trainer.accelerator_backend and torch_distrib.is_initialized():
-            best_model_path, best_model_score = trainer.accelerator_backend.broadcast((self.best_model_path,
-                                                                                       self.best_model_score))
-            # track the best model path and score rank 0
-            self.best_model_path = best_model_path
-            self.best_model_score = best_model_score
-
     def __validate_init_configuration(self):
         if self.save_top_k is not None and self.save_top_k < -1:
             raise MisconfigurationException(
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 44250ae905aba..ff4b29160d531 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -438,6 +438,7 @@ def fit(
 
         results = self.accelerator_backend.train()
         self.accelerator_backend.teardown()
+        self.accelerator_backend.barrier()
 
         # ----------------------------
         # POST-Training CLEAN UP
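
The patch replaces the per-callback broadcast of the best checkpoint path/score with a single barrier at the end of `fit`, so every process waits before teardown completes. Below is a minimal sketch of that synchronization pattern using plain `torch.distributed`; it is not the Lightning implementation, and `run_training`/`save_best_checkpoint` are hypothetical placeholders, only the `torch.distributed` calls are real APIs.

```python
import torch.distributed as dist


def save_best_checkpoint() -> None:
    # Placeholder for rank-0-only checkpoint bookkeeping,
    # e.g. torch.save(model.state_dict(), "best.ckpt").
    pass


def run_training(rank: int) -> None:
    # Placeholder for the training loop.
    if rank == 0:
        save_best_checkpoint()


def main() -> None:
    # Assumes the usual env vars (RANK, WORLD_SIZE, MASTER_ADDR, ...)
    # are set by a distributed launcher.
    dist.init_process_group(backend="gloo")
    run_training(dist.get_rank())

    # Wait for every process before tearing down, so non-zero ranks do not
    # exit (and the launcher does not kill the job) while rank 0 is still
    # finishing checkpoint work.
    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

Run under a distributed launcher, for example `torchrun --nproc_per_node=2 sketch.py`. Placing the barrier once in the trainer, rather than broadcasting state from inside the checkpoint callback, keeps the callback free of distributed logic while still guaranteeing all ranks reach the end of training together.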