Fix deadlock in checkpoint plugin
jeffpicard committed Dec 18, 2024
1 parent 56b2a9e commit a2edb9e
Showing 2 changed files with 5 additions and 6 deletions.
flair/trainers/plugins/functional/checkpoints.py (2 changes: 0 additions & 2 deletions)
@@ -30,8 +30,6 @@ def after_training_epoch(self, epoch, **kw):
             )
             model_name = "model_epoch_" + str(epoch) + ".pt"
             self.model.save(self.base_path / model_name, checkpoint=self.save_optimizer_state)
-            if torch.distributed.is_initialized():
-                torch.distributed.barrier()  # Prevent any process from loading a model until writing is complete
 
     @property
     def attach_to_all_processes(self) -> bool:
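Why dropping the barrier resolves the hang: a collective such as torch.distributed.barrier() blocks until every rank calls it, so it deadlocks as soon as some ranks never reach it, which is what happens when a plugin hook runs only on the main process (that is what the attach_to_all_processes property above controls). Below is a minimal sketch, not flair's API, of the safe placement: only one rank writes the file, but all ranks execute the collective.

import torch
import torch.distributed as dist

def save_checkpoint_safely(model: torch.nn.Module, path: str, is_main_process: bool) -> None:
    # Hypothetical helper for illustration only (not part of flair).
    if is_main_process:
        # Only the main process writes the checkpoint file ...
        torch.save(model.state_dict(), path)
    if dist.is_available() and dist.is_initialized():
        # ... but every rank waits here, so no rank reads a half-written file
        # and no rank is left waiting for peers that never call barrier().
        dist.barrier()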
flair/trainers/trainer.py (9 changes: 5 additions & 4 deletions)
@@ -657,9 +657,9 @@ def train_custom(
 
                 # forward and backward for batch
                 for batch_step_no, batch_step in enumerate(batch_steps):
-                    enable_gradient_sync = multi_gpu and batch_step_no == len(batch_steps) - 1
-                    sync_context = self.ddp_model.no_sync() if enable_gradient_sync else contextlib.nullcontext()
-                    with sync_context:
+                    disable_gradient_sync = multi_gpu and batch_step_no < len(batch_steps) - 1
+                    grad_sync = self.ddp_model.no_sync() if disable_gradient_sync else contextlib.nullcontext()
+                    with grad_sync:
                         # forward pass
                         with torch.autocast(device_type=flair.device.type, enabled=use_amp):
                             if multi_gpu:
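The old condition was inverted: enable_gradient_sync was true only on the final micro-batch, and that is exactly when no_sync() was entered, so DDP skipped the all-reduce on the one step that should synchronize and performed it on every other step. The new disable_gradient_sync suppresses synchronization on all but the last accumulation step. A minimal sketch of that standard pattern, using placeholder names (ddp_model, loss_fn, batch_steps) rather than flair internals:

import contextlib

def accumulate_and_step(ddp_model, optimizer, loss_fn, batch_steps):
    # Gradient accumulation under DistributedDataParallel: sync exactly once.
    for step_no, micro_batch in enumerate(batch_steps):
        last_step = step_no == len(batch_steps) - 1
        # no_sync() suppresses DDP's gradient all-reduce, so enter it on every
        # micro-batch except the last one.
        ctx = contextlib.nullcontext() if last_step else ddp_model.no_sync()
        with ctx:
            loss = loss_fn(ddp_model, micro_batch)
            loss.backward()
    optimizer.step()
    optimizer.zero_grad()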
@@ -691,7 +691,8 @@ def wrapped_forward_loss(*args, **kwargs2):
 
                 # do the optimizer step
                 if multi_gpu:
-                    self._scale_gradients(torch.distributed.get_world_size())  # DDP averages across processes but we want the sum
+                    # DDP averages across processes but we want the sum
+                    self._scale_gradients(torch.distributed.get_world_size())
                 scaler.unscale_(self.optimizer)
                 if max_grad_norm is not None:
                     gradient_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
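This last hunk only moves the inline comment onto its own line; the rescaling itself is unchanged. For context: DDP's all-reduce averages gradients across ranks, and the trainer multiplies by the world size to turn that mean back into the sum it wants. A tiny numeric illustration (plain Python, not flair code):

# Gradient of one parameter computed on two ranks.
world_size = 2
per_rank_grads = [2.0, 4.0]
ddp_averaged = sum(per_rank_grads) / world_size  # 3.0, what DDP leaves on every rank
summed = ddp_averaged * world_size               # 6.0, the sum the trainer wants
print(ddp_averaged, summed)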
