Skip to content

Commit

Permalink
Revert "Checkpoints Simplified (#2059)" (#2070)
Browse files Browse the repository at this point in the history
This reverts commit b25b7f0.
  • Loading branch information
dakinggg authored Mar 15, 2023
1 parent 1f35994 commit 0ce9a13
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 156 deletions.
42 changes: 16 additions & 26 deletions composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,7 @@ class Trainer:
.. seealso:: :mod:`composer.utils.reproducibility` for more details on reproducibility.
dist_timeout (float, optional): Timeout, in seconds, for initializing the distributed process group.
(default: ``600.0``)
(default: ``1800.0``)
ddp_sync_strategy (str | DDPSyncStrategy, optional): The strategy to use for synchronizing gradients.
Leave unset to let the trainer auto-configure this. See :class:`.DDPSyncStrategy`
for more details.
Expand Down Expand Up @@ -907,7 +907,7 @@ def __init__(
deterministic_mode: bool = False,

# Distributed Training
dist_timeout: float = 600.0,
dist_timeout: float = 1800.0,
ddp_sync_strategy: Optional[Union[str, DDPSyncStrategy]] = None,

# Profiling
Expand Down Expand Up @@ -1456,42 +1456,35 @@ def _get_autoresume_checkpoint(
f'Looking for autoresume checkpoint: {save_latest_remote_file_name} (remote), {latest_checkpoint_path} (local)'
)

# DeepSpeed and FSDP sharded state dict save checkpoints on every rank
if self.deepspeed_enabled or self.state.fsdp_sharded_state_dict_enabled:
if self.deepspeed_enabled:
# If latest checkpoint is not saved locally, try to fetch from loggers
if not os.path.exists(latest_checkpoint_path):
log.debug(f'Attempting to download the checkpoint on to rank {dist.get_global_rank()}')
os.makedirs(save_folder, exist_ok=True)
self._try_checkpoint_download(latest_checkpoint_path, save_latest_remote_file_name, loggers,
load_progress_bar)

# List of whether the checkpoint exists on each rank
# list of whether the checkpoint exists on each rank
latest_checkpoint_exists = dist.all_gather_object(os.path.exists(latest_checkpoint_path))

if all(latest_checkpoint_exists): # All paths exist, so return the path.
return latest_checkpoint_path
# Require all ranks to have their own local checkpoint if we wish to restore from it for
# deepspeed or fsdp + sharding
elif any(latest_checkpoint_exists): # Some but not all exist, which is very bad.
# Require all ranks to have their own local checkpoint if we wish to restore from it for deepspeed
if not all(latest_checkpoint_exists):
missing_ranks = [n for (n, exist) in enumerate(latest_checkpoint_exists) if not exist]
mode = 'Deepspeed' if self.deepspeed_enabled else 'FSDP sharding'
raise RuntimeError(f'{mode} was enabled, but checkpoints missing on ranks: {missing_ranks}')
else: # None of the paths exists, so no autoresume necessary.
return None
raise RuntimeError(f'Deepspeed was enabled, but checkpoints missing on ranks: {missing_ranks}')

# Otherwise, only local rank 0 saves checkpoints
return latest_checkpoint_path
else:
# Broadcast the local checkpoint path to all ranks
# broadcast the local checkpoint path to all ranks
latest_checkpoint_path_list = [os.path.abspath(latest_checkpoint_path)]
dist.broadcast_object_list(latest_checkpoint_path_list, src=0)
latest_checkpoint_path = latest_checkpoint_path_list[0]

# Broadcast the remote checkpoint path to all ranks
# broadcast the remote checkpoint path to all ranks
save_latest_remote_file_name_list = [save_latest_remote_file_name]
dist.broadcast_object_list(save_latest_remote_file_name_list, src=0)
save_latest_remote_file_name = save_latest_remote_file_name_list[0]

# Try to download the checkpoint on local rank 0 of all nodes
# try to download the checkpoint on local rank 0 of all nodes
if dist.get_local_rank() == 0 and not os.path.exists(latest_checkpoint_path):
log.debug(f'Attempting to download the checkpoint {save_latest_remote_file_name} on to all nodes')
os.makedirs(save_folder, exist_ok=True)
Expand All @@ -1504,13 +1497,10 @@ def _get_autoresume_checkpoint(
with open(signal_file_path, 'wb') as f:
f.write(b'local_rank0_completed_autoresume')

# Avoid the collective call until the local rank zero has finished trying to download the checkpoint
# so that we don't timeout for large downloads. Instead, we busy wait for the signal file to ensure
# synchronization intra-node before the collective call for inter-node synchronization.
dist.local_rank_zero_download_and_wait(signal_file_path)
if dist.get_local_rank() == 0:
os.remove(signal_file_path)
dist.barrier()
# avoid the collective call until the local rank zero has finished trying to download the checkpoint
# so that we don't timeout for large downloads
with dist.local_rank_zero_download_and_wait(signal_file_path):
dist.barrier()

# At this point the rank 0 filepath should exist on all ranks if the download succeeded
# list of whether the checkpoint exists on each rank
Expand All @@ -1524,7 +1514,7 @@ def _get_autoresume_checkpoint(
# If the checkpoint doesn't exist on rank 0, don't crash, so the initial autoresume run can succeed
return None
elif not all(latest_checkpoint_exists):
raise RuntimeError('Downloading the checkpoint to all nodes failed when using autoresume.')
raise RuntimeError('Downloading the checkpoint to all nodes failed')

return latest_checkpoint_path

Expand Down
Loading

0 comments on commit 0ce9a13

Please sign in to comment.