Removed model checkpoint code, added barrier to trainer to enforce we synchronize and wait for all processes to finish before completing training
SeanNaren committed Oct 24, 2020
1 parent b4d33d3 commit 473a3c5
Showing 3 changed files with 1 addition and 13 deletions.
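The barrier added here uses torch.distributed's collective synchronization. As a minimal sketch of the behavior the commit relies on (the worker function, gloo backend, and rendezvous address below are illustrative assumptions, not code from this commit):

import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int) -> None:
    # Each spawned process joins the same process group (gloo runs on CPU).
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29500",  # illustrative rendezvous address
        rank=rank,
        world_size=world_size,
    )
    # ... per-rank training work, which can finish at different times ...
    # barrier() blocks each process until every rank reaches this point.
    # This is the guarantee the trainer-level barrier in the diff below
    # enforces before training is considered complete.
    dist.barrier()
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)

In the commit itself, the same synchronization point is reached through accelerator_backend.barrier() after teardown, as shown in the trainer.py diff below.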
1 change: 0 additions & 1 deletion pytorch_lightning/accelerators/ddp_accelerator.py
@@ -145,7 +145,6 @@ def train(self):
         results = self.ddp_train(process_idx=self.task_idx, model=model)
         if 'WORLD_SIZE' in os.environ:
             del os.environ['WORLD_SIZE']
-        self.barrier('ddp_end_train')
         return results

     def training_step(self, args):
12 changes: 0 additions & 12 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -29,7 +29,6 @@

 import numpy as np
 import torch
-import torch.distributed as torch_distrib
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn, rank_zero_info
@@ -196,9 +195,6 @@ def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]):
         self.best_model_score = checkpointed_state["best_model_score"]
         self.best_model_path = checkpointed_state["best_model_path"]

-    def on_fit_end(self, trainer, pl_module) -> None:
-        self._sync_best_model_across_procs(trainer)
-
     def save_checkpoint(self, trainer, pl_module):
         """
         Performs the main logic around saving a checkpoint.
@@ -238,14 +234,6 @@ def save_checkpoint(self, trainer, pl_module):
         # Mode 2: save the last checkpoint
         self._save_last_checkpoint(trainer, pl_module, epoch, monitor_candidates, filepath)

-    def _sync_best_model_across_procs(self, trainer) -> None:
-        if trainer.accelerator_backend and torch_distrib.is_initialized():
-            best_model_path, best_model_score = trainer.accelerator_backend.broadcast((self.best_model_path,
-                                                                                       self.best_model_score))
-            # track the best model path and score rank 0
-            self.best_model_path = best_model_path
-            self.best_model_score = best_model_score
-
     def __validate_init_configuration(self):
         if self.save_top_k is not None and self.save_top_k < -1:
             raise MisconfigurationException(
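For context, the deleted _sync_best_model_across_procs used the accelerator's broadcast to copy rank 0's best checkpoint path and score to every process. A rough stand-in with plain torch.distributed (using broadcast_object_list from newer PyTorch releases; this is an assumed equivalent, not what Lightning's accelerator_backend.broadcast does internally) might look like:

import torch.distributed as dist

def sync_best_from_rank0(best_model_path, best_model_score):
    """Hypothetical helper: make every rank adopt rank 0's best path/score."""
    # Assumes a process group has already been initialized, mirroring the
    # is_initialized() check in the deleted method.
    if dist.is_available() and dist.is_initialized():
        payload = [best_model_path, best_model_score]
        # On non-zero ranks the list contents are overwritten in place with
        # rank 0's objects; rank 0's inputs are the source of truth.
        dist.broadcast_object_list(payload, src=0)
        best_model_path, best_model_score = payload
    return best_model_path, best_model_score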
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -438,6 +438,7 @@ def fit(

         results = self.accelerator_backend.train()
         self.accelerator_backend.teardown()
+        self.accelerator_backend.barrier()

         # ----------------------------
         # POST-Training CLEAN UP
