diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index ed2f74f9c3..8ea14de60c 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -803,7 +803,7 @@ class Trainer:
             .. seealso:: :mod:`composer.utils.reproducibility` for more details on reproducibility.
         dist_timeout (float, optional): Timeout, in seconds, for initializing the distributed process group.
-            (default: ``600.0``)
+            (default: ``1800.0``)
         ddp_sync_strategy (str | DDPSyncStrategy, optional): The strategy to use for synchronizing gradients.
             Leave unset to let the trainer auto-configure this. See :class:`.DDPSyncStrategy`
             for more details.
@@ -907,7 +907,7 @@ def __init__(
         deterministic_mode: bool = False,

         # Distributed Training
-        dist_timeout: float = 600.0,
+        dist_timeout: float = 1800.0,
         ddp_sync_strategy: Optional[Union[str, DDPSyncStrategy]] = None,

         # Profiling
diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py
index 3489f97994..2420e0efbf 100644
--- a/composer/utils/checkpoint.py
+++ b/composer/utils/checkpoint.py
@@ -38,6 +38,15 @@
 _DEEPSPEED_TAG = 'deepspeed'  # always tag with the same, deterministic name. We'll rename the tarball to the appropriate name.


+def _format_path_with_rank_zero(path: str) -> str:
+    """Formats ``path`` with the rank zero values."""
+    return path.format(
+        rank=0,
+        local_rank=0,
+        node_rank=0,
+    )
+
+
 def _format_path_with_current_rank(path: str) -> str:
     """Formats ``path`` formatted with the current rank values."""
     return path.format(
@@ -184,27 +193,31 @@ def load_checkpoint(
     """
     # download the checkpoint to the node-local folder
     log.debug('Loading checkpoint at %s', path)
-    # Each node gets one unique folder to store checkpoints that is shared amongst all local ranks in that node
-    tempdir_ctx = tempfile.TemporaryDirectory() if dist.get_local_rank() == 0 else contextlib.nullcontext(None)
+    # Each node gets one unique folder to store checkpoints that is shared amongst all local ranks in that node.
+    # If fsdp sharded state_dicts are enabled, then EVERY rank gets a unique checkpoint folder.
+    tempdir_ctx = (tempfile.TemporaryDirectory() if (state.fsdp_sharded_state_dict_enabled or
+                                                     dist.get_local_rank() == 0) else contextlib.nullcontext(None))
     with tempdir_ctx as tempdir:
         try:
-            # Get path to temporary folder for node-local checkpoints created on rank 0
-            node_checkpoint_folder = _get_local_rank_zero_path(tempdir)
+            # Get the path to the proper checkpoint folder corresponding to the current rank's node.
+            # If fsdp_sharded_state_dict_enabled, just use that rank's unique tempdir.
+            node_checkpoint_folder = (tempdir if state.fsdp_sharded_state_dict_enabled else
+                                      _get_node_checkpoint_download_folder(tempdir))
             assert node_checkpoint_folder is not None
-            composer_states_filepath, extracted_tar_checkpoint_folder = download_checkpoint(
+            composer_states_filepath, extracted_checkpoint_folder, extracted_rank_n = download_checkpoint(
                 path=path,
                 node_checkpoint_folder=node_checkpoint_folder,
                 object_store=object_store,
                 progress_bar=progress_bar,
                 fsdp_sharded_state_dict_enabled=state.fsdp_sharded_state_dict_enabled,
-                deepspeed_checkpoint=is_model_deepspeed(state.model),
             )

             rng_state_dicts = _restore_checkpoint(
                 state,
                 logger,
                 composer_states_filepath,
-                extracted_tar_checkpoint_folder,
+                extracted_rank_n,
+                extracted_checkpoint_folder,
                 load_weights_only=load_weights_only,
                 strict_model_weights=strict_model_weights,
                 ignore_keys=ignore_keys,
@@ -220,7 +233,7 @@ def load_checkpoint(
     return rng_state_dicts


-def _get_local_rank_zero_path(path: Optional[str]) -> str:
+def _get_node_checkpoint_download_folder(path: Optional[str]) -> str:
     """Broadcasts the ``path`` from the LOCAL rank zero to all LOCAL ranks."""
     local_rank_zero = dist.get_local_world_size() * dist.get_node_rank()
     paths = dist.all_gather_object(path)
@@ -235,77 +248,88 @@ def download_checkpoint(
     object_store: Optional[Union[ObjectStore, LoggerDestination]],
     progress_bar: bool,
     fsdp_sharded_state_dict_enabled: bool = False,
-    deepspeed_checkpoint: bool = False,
-) -> Tuple[str, Optional[str]]:
+) -> Tuple[str, Optional[str], bool]:
     """Download the checkpoint stored at ``path``, potentially in ``object_store``, to ``node_checkpoint_folder``.

-    Returns a tuple of (``composer_states_filepath``, ``extracted_checkpoint_folder``).
+    Returns a tuple of (``composer_states_filepath``, ``extracted_checkpoint_folder``, ``extracted_rank_n``).

     *   The ``composer_states_filepath``, is the path to the composer states, which can be passed into
         :meth:`torch.load`.
-    *   The ``extracted_tar_checkpoint_folder`` is the path to the checkpoint folder, which can be passed into
+    *   The ``extracted_checkpoint_folder`` is the path to the checkpoint folder, which can be passed into
         :meth:`deepspeed.DeepSpeedEngine.load_checkpoint`.
+    *   The ``extracted_rank_n`` is a boolean flag indicating whether a tarball was extracted on global
+        rank greater than 0.
""" log.debug('Downloading checkpoint to folder %s', node_checkpoint_folder) - # Files do not have extensions as it could be .tar, .pt, or something else + rank_zero_checkpoint_filepath = os.path.join(node_checkpoint_folder, 'rank0_checkpoint') rank_n_checkpoint_filepath = os.path.join(node_checkpoint_folder, f'rank{dist.get_global_rank()}_checkpoint') - extracted_tar_checkpoint_folder = None - - # DeepSpeed checkpoints must be tarballs - if deepspeed_checkpoint and not is_tar(path): - raise ValueError(f'Checkpoint at {path} is not a tarball, which is needed for DeepSpeed checkpoint') - - # Determine where checkpoint will be after downloading + extracted_checkpoint_folder = None + extracted_rank_n = False if is_tar(path): - # DeepSpeed checkpoints OR checkpoints of all formats with compression - extracted_tar_checkpoint_folder = os.path.join(node_checkpoint_folder, 'checkpoint') - composer_states_filepath = os.path.join(extracted_tar_checkpoint_folder, _COMPOSER_STATES_FILENAME) - elif fsdp_sharded_state_dict_enabled: - # FSDP sharded state dict has a Composer state dict per rank - composer_states_filepath = rank_n_checkpoint_filepath + extracted_checkpoint_folder = os.path.join(node_checkpoint_folder, 'checkpoint') + composer_states_filepath = os.path.join(extracted_checkpoint_folder, _COMPOSER_STATES_FILENAME) else: - # Vanilla Composer has a single Composer state dict on local rank zero - composer_states_filepath = os.path.join(node_checkpoint_folder, 'rank0_checkpoint') - - local_path = _format_path_with_current_rank(path) - local_rank_zero_path = _get_local_rank_zero_path(local_path) + # it's not an archive; it's just the composer state dict + # and only rank zero has this file unless fsdp_sharded_state_dict_enabled then + # every rank has it's own file. + extracted_checkpoint_folder = None + composer_states_filepath = (rank_n_checkpoint_filepath + if fsdp_sharded_state_dict_enabled else rank_zero_checkpoint_filepath) try: - # Download on local rank 0 or if path is different from local rank 0 - if dist.get_local_rank() == 0 or local_path != local_rank_zero_path: - get_file_succeeded = True + if ((fsdp_sharded_state_dict_enabled and dist.get_global_rank() == 0) or + (not fsdp_sharded_state_dict_enabled and dist.get_local_rank() == 0)): + # every NODE needs the GLOBAL rank zero checkpoint unless fsdp_sharded_state_dict_enabled. + path = _format_path_with_rank_zero(path) + get_file(destination=rank_zero_checkpoint_filepath, + path=path, + object_store=object_store, + progress_bar=progress_bar) + if extracted_checkpoint_folder is not None: + try: + with tarfile.open(rank_zero_checkpoint_filepath) as tarball: + tarball.extractall(extracted_checkpoint_folder) + except FileNotFoundError: + # Not re-raising the file-not-found error as that is irrelevant; + # the underlying issue is that the checkpoint file does not exist on the disk + # or could not be downloaded + raise RuntimeError(f'Checkpoint {path} does not exist') + + if rank_zero_checkpoint_filepath != rank_n_checkpoint_filepath: + # every RANK needs ITS OWN checkpoint. + # But, the global rank zero is a special case -- these files are the same! 
+            assert dist.get_global_rank() != 0, 'invariant violation'
+
             try:
-                get_file(
-                    destination=rank_n_checkpoint_filepath,
-                    path=_format_path_with_current_rank(path),
-                    object_store=object_store,
-                    progress_bar=progress_bar,
-                )
-            except FileNotFoundError as e:
-                get_file_succeeded = False
-                # If the checkpoint is not found, raise the error only on local rank zero, which
-                # requires a checkpoint, or if FSDP sharded checkpoints or deepspeed are enabled,
-                # which require a checkpoint on all ranks
-                if dist.get_local_rank() == 0 or fsdp_sharded_state_dict_enabled or deepspeed_checkpoint:
-                    raise e
-                # Otherwise, ignore error as standard checkpointing does not have rank-local checkpoints
+                get_file(destination=rank_n_checkpoint_filepath,
+                         path=_format_path_with_current_rank(path),
+                         object_store=object_store,
+                         progress_bar=progress_bar)
+            except FileNotFoundError:
+                # Allowing not-found errors to be ignored as sometimes there won't be rank-local checkpoints
+                # (e.g. when using neither deepspeed nor fsdp sharded checkpoints)
                 pass

-            # Extract tarballs, which happens for DeepSpeed or compression
-            if get_file_succeeded and extracted_tar_checkpoint_folder is not None:
-                with tarfile.open(rank_n_checkpoint_filepath) as tarball:
-                    tarball.extractall(extracted_tar_checkpoint_folder)
+        if extracted_checkpoint_folder is not None:
+            try:
+                # it's an archive and needs to be extracted
+                with tarfile.open(rank_n_checkpoint_filepath) as tarball:
+                    tarball.extractall(extracted_checkpoint_folder)
+                extracted_rank_n = True
+            except FileNotFoundError:
+                # this will happen most of the time (i.e. whenever deepspeed
+                # is not being used) so not logging anything
+                pass
+
     finally:
-        # Wait for all checkpoints on the node to finish downloading. First, we busy wait until
-        # file exists, ensuring synchronization intra-node. Next, we use `dist.barrier()` to
-        # sync across nodes. This is necessary to avoid timing out on the barrier when downloading
-        # large checkpoints, which may exceed the normal timeout. Now, a timeout is only
-        # encountered if the difference between checkpoint download times on the slowest and
-        # fastest nodes exceeds the timeout.
-        dist.local_rank_zero_download_and_wait(composer_states_filepath)
+        # Wait for all checkpoints on the node to finish downloading.
+        # Putting the barrier in a finally so the rank will always block on the barrier,
+        # even if it has an exception.
+        # Any exception will be re-raised after the barrier passes. The launcher script
+        # will detect the process crash and terminate the other ranks
         dist.barrier()

-    return composer_states_filepath, extracted_tar_checkpoint_folder
+    return composer_states_filepath, extracted_checkpoint_folder, extracted_rank_n


 def _flatten_keys(obj: Any, paths: List[str], existing_path: str):
@@ -399,7 +423,8 @@ def _restore_checkpoint(
     state: State,
     logger: Logger,
     composer_states_filepath: str,
-    extracted_tar_checkpoint_folder: Optional[str],
+    extracted_rank_n: bool,
+    extracted_checkpoint_folder: Optional[str],
     load_weights_only: bool,
     strict_model_weights: bool,
     ignore_keys: Optional[Union[List[str], Callable[[Dict], None]]],
@@ -407,6 +432,7 @@
     algorithm_passes: Optional[List[AlgorithmPass]],
 ) -> Optional[List[Dict[str, Any]]]:
     """Restore a checkpoint into ``state`` and returns the rng state dicts (if ``load_weights_only`` is False)."""
+    # Now, all ranks load the checkpoint that local rank zero downloaded
     state_dict = safe_torch_load(composer_states_filepath)
     if ignore_keys:
         # Filter provided list of key paths
@@ -417,11 +443,15 @@
     log.debug(f"Loaded checkpoint with keys {state_dict.keys()} and state keys {state_dict['state'].keys()}")

     if is_model_deepspeed(state.model):
-        if extracted_tar_checkpoint_folder is None:
+        if extracted_checkpoint_folder is None:
             raise RuntimeError('Deepspeed checkpoints require a tarball, not a weights file.')

+        global_rank = dist.get_global_rank()
+        if global_rank > 0 and not extracted_rank_n:
+            raise RuntimeError(f'Deepspeed checkpoint missing for rank {global_rank}')
+
         load_path, _ = state.deepspeed_model.load_checkpoint(
-            extracted_tar_checkpoint_folder,
+            extracted_checkpoint_folder,
             tag=_DEEPSPEED_TAG,
             load_module_only=load_weights_only,
             load_module_strict=strict_model_weights,
@@ -452,6 +482,7 @@ def save_checkpoint(
     *,
     weights_only: bool = False,
 ) -> Union[str, None]:  # noqa: D103
+    log.debug('Saving checkpoint to %s', filename)

     is_deepspeed = is_model_deepspeed(state.model)

@@ -468,7 +499,7 @@
     if dirname:
         os.makedirs(dirname, exist_ok=True)

-    # Only rank 0 saves the state_dict unless state.fsdp_sharded_state_dict_enabled=True
+    # only rank 0 saves the state_dict unless state.fsdp_sharded_state_dict_enabled=True.
     if dist.get_global_rank() == 0 or state.fsdp_sharded_state_dict_enabled:
         with open(save_filename, 'wb') as f:
             torch.save(state_dict, f)
@@ -476,7 +507,7 @@
         if is_tar(save_filename):
             _compress_file(save_filename, basename=_COMPOSER_STATES_FILENAME)

-    # All ranks save for deepspeed
+    # all ranks save for deepspeed
     if is_deepspeed:
         _save_deepspeed_model(state.deepspeed_model, save_filename)

@@ -485,7 +516,9 @@
     if dist.get_global_rank() == 0 or is_deepspeed or state.fsdp_sharded_state_dict_enabled:
         assert os.path.exists(save_filename), 'Expected file to have been saved.'
         return save_filename
-    return None
+    else:
+        # no file saved
+        return None


 def _compress_file(filename: str, basename: str):
diff --git a/composer/utils/dist.py b/composer/utils/dist.py
index 20724449ae..a570d045cb 100644
--- a/composer/utils/dist.py
+++ b/composer/utils/dist.py
@@ -351,7 +351,7 @@ def is_initialized():
     return dist.is_initialized()


-def initialize_dist(device: Union[str, Device], timeout: float = 600.0):
+def initialize_dist(device: Union[str, Device], timeout: float = 300.0):
     """Initialize the default PyTorch distributed process group.

     This function assumes that the following environment variables are set:
@@ -374,7 +374,7 @@ def initialize_dist(device: Union[str, Device], timeout: float = 600.0):
             interpreted. Either a string corresponding to a device (one of ``'cpu'``, ``'gpu'``, ``'mps'``, or
             ``'tpu'``) or a :class:`.Device`.
         timeout (float, optional): The timeout for operations executed against the process
-            group, expressed in seconds. (default: ``600.0``).
+            group, expressed in seconds. (default: ``300.0``).
     """
     # If device is string, get corresponding composer.devices.Device object
     device_obj = get_device(device)
diff --git a/composer/utils/inference.py b/composer/utils/inference.py
index 1cb3da11c7..13bfd563f4 100644
--- a/composer/utils/inference.py
+++ b/composer/utils/inference.py
@@ -168,10 +168,10 @@ def export_for_inference(
         # download checkpoint and load weights only
         log.debug('Loading checkpoint at %s', load_path)
         with tempfile.TemporaryDirectory() as tempdir:
-            composer_states_filepath, _ = download_checkpoint(path=load_path,
-                                                              node_checkpoint_folder=tempdir,
-                                                              object_store=load_object_store,
-                                                              progress_bar=True)
+            composer_states_filepath, _, _ = download_checkpoint(path=load_path,
+                                                                 node_checkpoint_folder=tempdir,
+                                                                 object_store=load_object_store,
+                                                                 progress_bar=True)
             state_dict = safe_torch_load(composer_states_filepath)
             missing_keys, unexpected_keys = model.load_state_dict(state_dict['state']['model'], strict=load_strict)
             if len(missing_keys) > 0:
diff --git a/tests/utils/test_autolog_hparams.py b/tests/utils/test_autolog_hparams.py
index 73f54b330c..667b6fa81a 100644
--- a/tests/utils/test_autolog_hparams.py
+++ b/tests/utils/test_autolog_hparams.py
@@ -169,7 +169,7 @@ def test_extract_hparams_trainer():
         'deterministic_mode': False,

         # Distributed Training
-        'dist_timeout': 600.0,
+        'dist_timeout': 1800.0,
         'ddp_sync_strategy': None,

         # Profiling
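
For context on the templated checkpoint paths used in the diff: a minimal standalone sketch of the rank-zero vs. current-rank formatting. The template string and the example rank values are hypothetical; the real _format_path_with_rank_zero and _format_path_with_current_rank helpers in composer/utils/checkpoint.py read the ranks from composer.utils.dist rather than taking them as arguments.

# Standalone illustration only -- not the library code.

def format_path_with_rank_zero(path: str) -> str:
    # Every placeholder resolves to rank 0: this is the file every node downloads.
    return path.format(rank=0, local_rank=0, node_rank=0)


def format_path_with_current_rank(path: str, rank: int, local_rank: int, node_rank: int) -> str:
    # Placeholders resolve to the calling rank's values: the rank-local file that may
    # also be fetched (e.g. for DeepSpeed or FSDP sharded state dicts).
    return path.format(rank=rank, local_rank=local_rank, node_rank=node_rank)


template = 'checkpoints/ep10-rank{rank}.pt'  # hypothetical checkpoint name template
print(format_path_with_rank_zero(template))  # checkpoints/ep10-rank0.pt
print(format_path_with_current_rank(template, rank=3, local_rank=3, node_rank=0))  # checkpoints/ep10-rank3.pt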
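The finally-block change in download_checkpoint relies on every rank reaching the barrier even when its own download raises. A minimal sketch of that pattern, assuming torch.distributed is already initialized and using a stand-in download function in place of get_file:

import torch.distributed as torch_dist


def download_rank_local_file() -> None:
    # Stand-in for the per-rank get_file(...) call; ranks without a rank-local
    # checkpoint see a FileNotFoundError here.
    raise FileNotFoundError('no rank-local checkpoint for this rank')


def download_with_barrier() -> None:
    try:
        download_rank_local_file()
    except FileNotFoundError:
        # Missing rank-local checkpoints are expected for plain (non-sharded) runs.
        pass
    finally:
        # The barrier sits in a finally block so every rank blocks here even if an
        # unexpected exception was raised above; the exception propagates afterwards
        # and the launcher tears down the remaining ranks.
        torch_dist.barrier()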