From 0ce9a13a245d8c021ed1c860f3121d87df392561 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Tue, 14 Mar 2023 21:05:18 -0700 Subject: [PATCH] Revert "Checkpoints Simplified (#2059)" (#2070) This reverts commit b25b7f03ae38843419adc23cdcf45cfc2a0557dc. --- composer/trainer/trainer.py | 42 ++---- composer/utils/checkpoint.py | 183 +++++++++++++---------- composer/utils/dist.py | 4 +- composer/utils/inference.py | 8 +- tests/trainer/test_sharded_checkpoint.py | 47 +----- tests/utils/test_autolog_hparams.py | 2 +- 6 files changed, 130 insertions(+), 156 deletions(-) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index b2dfe6b4bf8..07411def850 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -803,7 +803,7 @@ class Trainer: .. seealso:: :mod:`composer.utils.reproducibility` for more details on reproducibility. dist_timeout (float, optional): Timeout, in seconds, for initializing the distributed process group. - (default: ``600.0``) + (default: ``1800.0``) ddp_sync_strategy (str | DDPSyncStrategy, optional): The strategy to use for synchronizing gradients. Leave unset to let the trainer auto-configure this. See :class:`.DDPSyncStrategy` for more details. @@ -907,7 +907,7 @@ def __init__( deterministic_mode: bool = False, # Distributed Training - dist_timeout: float = 600.0, + dist_timeout: float = 1800.0, ddp_sync_strategy: Optional[Union[str, DDPSyncStrategy]] = None, # Profiling @@ -1456,8 +1456,7 @@ def _get_autoresume_checkpoint( f'Looking for autoresume checkpoint: {save_latest_remote_file_name} (remote), {latest_checkpoint_path} (local)' ) - # DeepSpeed and FSDP sharded state dict save checkpoints on every rank - if self.deepspeed_enabled or self.state.fsdp_sharded_state_dict_enabled: + if self.deepspeed_enabled: # If latest checkpoint is not saved locally, try to fetch from loggers if not os.path.exists(latest_checkpoint_path): log.debug(f'Attempting to download the checkpoint on to rank {dist.get_global_rank()}') @@ -1465,33 +1464,27 @@ def _get_autoresume_checkpoint( self._try_checkpoint_download(latest_checkpoint_path, save_latest_remote_file_name, loggers, load_progress_bar) - # List of whether the checkpoint exists on each rank + # list of whether the checkpoint exists on each rank latest_checkpoint_exists = dist.all_gather_object(os.path.exists(latest_checkpoint_path)) - if all(latest_checkpoint_exists): # All paths exist, so return the path. - return latest_checkpoint_path - # Require all ranks to have their own local checkpoint if we wish to restore from it for - # deepspeed or fsdp + sharding - elif any(latest_checkpoint_exists): # Some but not all exist, which is very bad. + # Require all ranks to have their own local checkpoint if we wish to restore from it for deepspeed + if not all(latest_checkpoint_exists): missing_ranks = [n for (n, exist) in enumerate(latest_checkpoint_exists) if not exist] - mode = 'Deepspeed' if self.deepspeed_enabled else 'FSDP sharding' - raise RuntimeError(f'{mode} was enabled, but checkpoints missing on ranks: {missing_ranks}') - else: # None of the paths exists, so no autoresume necessary. 
- return None + raise RuntimeError(f'Deepspeed was enabled, but checkpoints missing on ranks: {missing_ranks}') - # Otherwise, only local rank 0 saves checkpoints + return latest_checkpoint_path else: - # Broadcast the local checkpoint path to all ranks + # broadcast the local checkpoint path to all ranks latest_checkpoint_path_list = [os.path.abspath(latest_checkpoint_path)] dist.broadcast_object_list(latest_checkpoint_path_list, src=0) latest_checkpoint_path = latest_checkpoint_path_list[0] - # Broadcast the remote checkpoint path to all ranks + # broadcast the remote checkpoint path to all ranks save_latest_remote_file_name_list = [save_latest_remote_file_name] dist.broadcast_object_list(save_latest_remote_file_name_list, src=0) save_latest_remote_file_name = save_latest_remote_file_name_list[0] - # Try to download the checkpoint on local rank 0 of all nodes + # try to download the checkpoint on local rank 0 of all nodes if dist.get_local_rank() == 0 and not os.path.exists(latest_checkpoint_path): log.debug(f'Attempting to download the checkpoint {save_latest_remote_file_name} on to all nodes') os.makedirs(save_folder, exist_ok=True) @@ -1504,13 +1497,10 @@ def _get_autoresume_checkpoint( with open(signal_file_path, 'wb') as f: f.write(b'local_rank0_completed_autoresume') - # Avoid the collective call until the local rank zero has finished trying to download the checkpoint - # so that we don't timeout for large downloads. Instead, we busy wait for the signal file to ensure - # synchronization intra-node before the collective call for inter-node synchronization. - dist.local_rank_zero_download_and_wait(signal_file_path) - if dist.get_local_rank() == 0: - os.remove(signal_file_path) - dist.barrier() + # avoid the collective call until the local rank zero has finished trying to download the checkpoint + # so that we don't timeout for large downloads + with dist.local_rank_zero_download_and_wait(signal_file_path): + dist.barrier() # At this point the rank 0 filepath should exist on all ranks if the download succeeded # list of whether the checkpoint exists on each rank @@ -1524,7 +1514,7 @@ def _get_autoresume_checkpoint( # If the checkpoint doesn't exist on rank 0, don't crash, so the initial autoresume run can succeed return None elif not all(latest_checkpoint_exists): - raise RuntimeError('Downloading the checkpoint to all nodes failed when using autoresume.') + raise RuntimeError('Downloading the checkpoint to all nodes failed') return latest_checkpoint_path diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index 45f1a643d08..2420e0efbfa 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -38,6 +38,15 @@ _DEEPSPEED_TAG = 'deepspeed' # always tag with the same, deterministic name. We'll rename the tarball to the appropriate name. 
+def _format_path_with_rank_zero(path: str) -> str: + """Formats ``path`` with the rank zero values.""" + return path.format( + rank=0, + local_rank=0, + node_rank=0, + ) + + def _format_path_with_current_rank(path: str) -> str: """Formats ``path`` formatted with the current rank values.""" return path.format( @@ -184,27 +193,31 @@ def load_checkpoint( """ # download the checkpoint to the node-local folder log.debug('Loading checkpoint at %s', path) - # Each node gets one unique folder to store checkpoints that is shared amongst all local ranks in that node - tempdir_ctx = tempfile.TemporaryDirectory() if dist.get_local_rank() == 0 else contextlib.nullcontext(None) + # Each node gets one unique folder to store checkpoints that is shared amongst all local ranks in that node. + # If fsdp sharded state_dicts is enabled then EVERY rank gets a unique checkpoint folder. + tempdir_ctx = (tempfile.TemporaryDirectory() if (state.fsdp_sharded_state_dict_enabled or + dist.get_local_rank() == 0) else contextlib.nullcontext(None)) with tempdir_ctx as tempdir: try: - # Get path to temporary folder for node-local checkpoints created on rank 0 - node_checkpoint_folder = _get_local_rank_zero_path(tempdir) + # Get the path to the proper checkpoint folder corresponding to the current rank's node. + # If fsdp_sharded_state_dict_enabled then just use that rank's unique tempdir. + node_checkpoint_folder = (tempdir if state.fsdp_sharded_state_dict_enabled else + _get_node_checkpoint_download_folder(tempdir)) assert node_checkpoint_folder is not None - composer_states_filepath, extracted_tar_checkpoint_folder = download_checkpoint( + composer_states_filepath, extracted_checkpoint_folder, extracted_rank_n = download_checkpoint( path=path, node_checkpoint_folder=node_checkpoint_folder, object_store=object_store, progress_bar=progress_bar, fsdp_sharded_state_dict_enabled=state.fsdp_sharded_state_dict_enabled, - deepspeed_checkpoint=is_model_deepspeed(state.model), ) rng_state_dicts = _restore_checkpoint( state, logger, composer_states_filepath, - extracted_tar_checkpoint_folder, + extracted_rank_n, + extracted_checkpoint_folder, load_weights_only=load_weights_only, strict_model_weights=strict_model_weights, ignore_keys=ignore_keys, @@ -220,7 +233,7 @@ def load_checkpoint( return rng_state_dicts -def _get_local_rank_zero_path(path: Optional[str]) -> str: +def _get_node_checkpoint_download_folder(path: Optional[str]) -> str: """Broadcasts the ``path`` from the LOCAL rank zero to all LOCAL ranks.""" local_rank_zero = dist.get_local_world_size() * dist.get_node_rank() paths = dist.all_gather_object(path) @@ -235,92 +248,88 @@ def download_checkpoint( object_store: Optional[Union[ObjectStore, LoggerDestination]], progress_bar: bool, fsdp_sharded_state_dict_enabled: bool = False, - deepspeed_checkpoint: bool = False, -) -> Tuple[str, Optional[str]]: +) -> Tuple[str, Optional[str], bool]: """Download the checkpoint stored at ``path``, potentially in ``object_store``, to ``node_checkpoint_folder``. - Args: - path (str): The path to the checkpoint. - node_checkpoint_folder (str): The path to the node-local folder to store the checkpoint. - object_store (Optional[Union[ObjectStore, LoggerDestination]]): The object store to download the checkpoint from. - progress_bar (bool): Whether to display a progress bar. - fsdp_sharded_state_dict_enabled (bool, optional): Whether to enable FSDP sharded state dict. (default: ``False``) - deepspeed_checkpoint (bool, optional): Whether the checkpoint is a DeepSpeed checkpoint. 
(default: ``False``) - - Returns a tuple of (``composer_states_filepath``, ``extracted_checkpoint_folder``). + Returns a tuple of (``composer_states_filepath``, ``extracted_checkpoint_folder``, ``extracted_rank_n``). * The ``composer_states_filepath``, is the path to the composer states, which can be passed into :meth:`torch.load`. - * The ``extracted_tar_checkpoint_folder`` is the path to the checkpoint folder, which can be passed into + * The ``extracted_checkpoint_folder`` is the path to the checkpoint folder, which can be passed into :meth:`deepspeed.DeepSpeedEngine.load_checkpoint`. + * The ``extracted_rank_n`` is a boolean flag indicating whether a tarball was extracted on global + rank greater than 0. """ log.debug('Downloading checkpoint to folder %s', node_checkpoint_folder) - # Files do not have extensions as it could be .tar, .pt, or something else - rank_n_checkpoint_filepath = os.path.join(node_checkpoint_folder, f'rank{dist.get_local_rank()}_checkpoint') - extracted_tar_checkpoint_folder = None - - # DeepSpeed checkpoints must be tarballs - if deepspeed_checkpoint and not is_tar(path): - raise ValueError(f'Checkpoint at {path} is not a tarball, which is needed for DeepSpeed checkpoint') - - # Determine where checkpoint will be after downloading + rank_zero_checkpoint_filepath = os.path.join(node_checkpoint_folder, 'rank0_checkpoint') + rank_n_checkpoint_filepath = os.path.join(node_checkpoint_folder, f'rank{dist.get_global_rank()}_checkpoint') + extracted_checkpoint_folder = None + extracted_rank_n = False if is_tar(path): - # DeepSpeed checkpoints OR checkpoints of all formats with compression - extracted_tar_checkpoint_folder = os.path.join(node_checkpoint_folder, 'checkpoint') - composer_states_filepath = os.path.join(extracted_tar_checkpoint_folder, _COMPOSER_STATES_FILENAME) - elif fsdp_sharded_state_dict_enabled: - # FSDP sharded state dict has a Composer state dict per rank - composer_states_filepath = rank_n_checkpoint_filepath + extracted_checkpoint_folder = os.path.join(node_checkpoint_folder, 'checkpoint') + composer_states_filepath = os.path.join(extracted_checkpoint_folder, _COMPOSER_STATES_FILENAME) else: - # Vanilla Composer has a single Composer state dict on local rank zero - composer_states_filepath = os.path.join(node_checkpoint_folder, 'rank0_checkpoint') - - local_path = _format_path_with_current_rank(path) - local_rank_zero_path = _get_local_rank_zero_path(local_path) + # it's not an archive; it's just the composer state dict + # and only rank zero has this file unless fsdp_sharded_state_dict_enabled then + # every rank has it's own file. + extracted_checkpoint_folder = None + composer_states_filepath = (rank_n_checkpoint_filepath + if fsdp_sharded_state_dict_enabled else rank_zero_checkpoint_filepath) try: - # Download on local rank 0 or if path is different from local rank 0 - if dist.get_local_rank() == 0 or local_path != local_rank_zero_path: - get_file_succeeded = True + if ((fsdp_sharded_state_dict_enabled and dist.get_global_rank() == 0) or + (not fsdp_sharded_state_dict_enabled and dist.get_local_rank() == 0)): + # every NODE needs the GLOBAL rank zero checkpoint unless fsdp_sharded_state_dict_enabled. 
+ path = _format_path_with_rank_zero(path) + get_file(destination=rank_zero_checkpoint_filepath, + path=path, + object_store=object_store, + progress_bar=progress_bar) + if extracted_checkpoint_folder is not None: + try: + with tarfile.open(rank_zero_checkpoint_filepath) as tarball: + tarball.extractall(extracted_checkpoint_folder) + except FileNotFoundError: + # Not re-raising the file-not-found error as that is irrelevant; + # the underlying issue is that the checkpoint file does not exist on the disk + # or could not be downloaded + raise RuntimeError(f'Checkpoint {path} does not exist') + + if rank_zero_checkpoint_filepath != rank_n_checkpoint_filepath: + # every RANK needs ITS OWN checkpoint. + # But, the global rank zero is a special case -- these files are the same! + assert dist.get_global_rank() != 0, 'invariant violation' + try: - get_file( - path=_format_path_with_current_rank(path), - destination=rank_n_checkpoint_filepath, - object_store=object_store, - progress_bar=progress_bar, - ) - except FileNotFoundError as e: - get_file_succeeded = False - # If the checkpoint is not found, raise the error only on local rank zero, which - # requires a checkpoint, or if FSDP sharded checkpoints or deepspeed are enabled, - # which require a checkpoint on all ranks - if dist.get_local_rank() == 0 or fsdp_sharded_state_dict_enabled or deepspeed_checkpoint: - raise e - # Otherwise, ignore error as standard checkpointing does not have rank-local checkpoints + get_file(destination=rank_n_checkpoint_filepath, + path=_format_path_with_current_rank(path), + object_store=object_store, + progress_bar=progress_bar) + except FileNotFoundError: + # Allowing not-found errors to be ignored as sometimes there won't be rank-local checkpoints + # (e.g. when not using deepspeed nor using fsdp sharded checkpoints) pass - # Extract tarballs, which happens for DeepSpeed or compression - if get_file_succeeded and extracted_tar_checkpoint_folder is not None: - with tarfile.open(rank_n_checkpoint_filepath) as tarball: - tarball.extractall(extracted_tar_checkpoint_folder) - finally: - # Wait for all checkpoints on the node to finish downloading. First, we busy wait until - # file exists, ensuring synchronization intra-node. Next, we use `dist.barrier()` to - # sync across nodes. This is necessary to avoid timing out on the barrier when downloading - # large checkpoints, which may exceed the normal timeout. Now, a timeout is only - # encountered if the difference between checkpoint download times on the slowest and - # fastest nodes exceeds the timeout. - signal_file_path = local_rank_zero_path + '.local_rank0_completed' - if dist.get_local_rank() == 0: - with open(signal_file_path, 'wb') as f: - f.write(b'local_rank0_completed') - dist.local_rank_zero_download_and_wait(signal_file_path) - if dist.get_local_rank() == 0: - os.remove(signal_file_path) + if extracted_checkpoint_folder is not None: + try: + # it's an archive and needs to be extracted + with tarfile.open(rank_n_checkpoint_filepath) as tarball: + tarball.extractall(extracted_checkpoint_folder) + extracted_rank_n = True + except FileNotFoundError: + # this will happen most of the time (i.e. whenever deepspeed + # is not being used) so not logging anything + pass + finally: + # Wait for all checkpoints on the node to finish downloading + # Putting the barrier in a finally so the rank will always block on the barrier, + # even if it has an exception. + # Any exception will be re-raised after the barrier passes. 
The launcher script + # will detect the process crash and terminate the other ranks dist.barrier() - return composer_states_filepath, extracted_tar_checkpoint_folder + return composer_states_filepath, extracted_checkpoint_folder, extracted_rank_n def _flatten_keys(obj: Any, paths: List[str], existing_path: str): @@ -414,7 +423,8 @@ def _restore_checkpoint( state: State, logger: Logger, composer_states_filepath: str, - extracted_tar_checkpoint_folder: Optional[str], + extracted_rank_n: bool, + extracted_checkpoint_folder: Optional[str], load_weights_only: bool, strict_model_weights: bool, ignore_keys: Optional[Union[List[str], Callable[[Dict], None]]], @@ -422,6 +432,7 @@ def _restore_checkpoint( algorithm_passes: Optional[List[AlgorithmPass]], ) -> Optional[List[Dict[str, Any]]]: """Restore a checkpoint into ``state`` and returns the rng state dicts (if ``load_weights_only`` is False).""" + # Now, all ranks load the checkpoint that local rank zero downloaded state_dict = safe_torch_load(composer_states_filepath) if ignore_keys: # Filter provided list of key paths @@ -432,11 +443,15 @@ def _restore_checkpoint( log.debug(f"Loaded checkpoint with keys {state_dict.keys()} and state keys {state_dict['state'].keys()}") if is_model_deepspeed(state.model): - if extracted_tar_checkpoint_folder is None: + if extracted_checkpoint_folder is None: raise RuntimeError('Deepspeed checkpoints require a tarball, not a weights file.') + global_rank = dist.get_global_rank() + if global_rank > 0 and not extracted_rank_n: + raise RuntimeError(f'Deepspeed checkpoint missing for rank {global_rank}') + load_path, _ = state.deepspeed_model.load_checkpoint( - extracted_tar_checkpoint_folder, + extracted_checkpoint_folder, tag=_DEEPSPEED_TAG, load_module_only=load_weights_only, load_module_strict=strict_model_weights, @@ -451,7 +466,6 @@ def _restore_checkpoint( exclude_algorithms=exclude_algorithms, algorithm_passes=algorithm_passes, ) - if not load_weights_only: state.load_state_dict( state_dict['state'], @@ -468,6 +482,7 @@ def save_checkpoint( *, weights_only: bool = False, ) -> Union[str, None]: # noqa: D103 + log.debug('Saving checkpoint to %s', filename) is_deepspeed = is_model_deepspeed(state.model) @@ -484,7 +499,7 @@ def save_checkpoint( if dirname: os.makedirs(dirname, exist_ok=True) - # Only rank 0 saves the state_dict unless state.fsdp_sharded_state_dict_enabled=True + # only rank 0 saves the state_dict unless state.fsdp_sharded_state_dict_enabled=True. if dist.get_global_rank() == 0 or state.fsdp_sharded_state_dict_enabled: with open(save_filename, 'wb') as f: torch.save(state_dict, f) @@ -492,7 +507,7 @@ def save_checkpoint( if is_tar(save_filename): _compress_file(save_filename, basename=_COMPOSER_STATES_FILENAME) - # All ranks save for deepspeed + # all ranks save for deepspeed if is_deepspeed: _save_deepspeed_model(state.deepspeed_model, save_filename) @@ -501,7 +516,9 @@ def save_checkpoint( if dist.get_global_rank() == 0 or is_deepspeed or state.fsdp_sharded_state_dict_enabled: assert os.path.exists(save_filename), 'Expected file to have been saved.' 
return save_filename - return None + else: + # no file saved + return None def _compress_file(filename: str, basename: str): diff --git a/composer/utils/dist.py b/composer/utils/dist.py index 20724449ae0..a570d045cb2 100644 --- a/composer/utils/dist.py +++ b/composer/utils/dist.py @@ -351,7 +351,7 @@ def is_initialized(): return dist.is_initialized() -def initialize_dist(device: Union[str, Device], timeout: float = 600.0): +def initialize_dist(device: Union[str, Device], timeout: float = 300.0): """Initialize the default PyTorch distributed process group. This function assumes that the following environment variables are set: @@ -374,7 +374,7 @@ def initialize_dist(device: Union[str, Device], timeout: float = 600.0): interpreted. Either a string corresponding to a device (one of ``'cpu'``, ``'gpu'``, ``'mps'``, or ``'tpu'``) or a :class:`.Device`. timeout (float, optional): The timeout for operations executed against the process - group, expressed in seconds. (default: ``600.0``). + group, expressed in seconds. (default: ``300.0``). """ # If device is string, get corresponding composer.devices.Device object device_obj = get_device(device) diff --git a/composer/utils/inference.py b/composer/utils/inference.py index dff1af18065..f0af21c9bff 100644 --- a/composer/utils/inference.py +++ b/composer/utils/inference.py @@ -174,10 +174,10 @@ def export_for_inference( # download checkpoint and load weights only log.debug('Loading checkpoint at %s', load_path) with tempfile.TemporaryDirectory() as tempdir: - composer_states_filepath, _ = download_checkpoint(path=load_path, - node_checkpoint_folder=tempdir, - object_store=load_object_store, - progress_bar=True) + composer_states_filepath, _, _ = download_checkpoint(path=load_path, + node_checkpoint_folder=tempdir, + object_store=load_object_store, + progress_bar=True) state_dict = safe_torch_load(composer_states_filepath) missing_keys, unexpected_keys = model.load_state_dict(state_dict['state']['model'], strict=load_strict) if len(missing_keys) > 0: diff --git a/tests/trainer/test_sharded_checkpoint.py b/tests/trainer/test_sharded_checkpoint.py index 359e8c6045e..d9ae5d6b924 100644 --- a/tests/trainer/test_sharded_checkpoint.py +++ b/tests/trainer/test_sharded_checkpoint.py @@ -22,9 +22,7 @@ def get_trainer(save_folder=None, num_features=2, num_classes=2, fsdp_state_dict_type='full', - load_path=None, - autoresume=False, - run_name=None): + load_path=None): model = SimpleModel(num_features=num_features, num_classes=num_classes) dataset = RandomClassificationDataset(shape=(num_features, 1, 1), size=128) dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset), batch_size=32) @@ -42,12 +40,9 @@ def get_trainer(save_folder=None, max_duration='2ba', save_interval='2ba', save_filename=save_filename, - save_overwrite=False, load_path=load_path, progress_bar=False, log_to_console=False, - autoresume=autoresume, - run_name=run_name, ) return trainer @@ -182,31 +177,17 @@ def test_fsdp_full_state_dict_save(world_size, tmp_path: pathlib.Path): @pytest.mark.gpu @world_size(2) -@pytest.mark.parametrize('autoresume', [True, False]) @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'), reason='requires PyTorch 1.13 or higher') -def test_fsdp_full_state_dict_load(world_size, tmp_path: pathlib.Path, autoresume: bool): - if autoresume: - run_name = 'my-cool-autoresume-run' - else: - run_name = None +def test_fsdp_full_state_dict_load(world_size, tmp_path: pathlib.Path): save_folder = tmp_path save_filename = 'rank{rank}.pt' - 
trainer1 = get_trainer(save_folder=str(save_folder), - save_filename=save_filename, - fsdp_state_dict_type='full', - run_name=run_name, - autoresume=autoresume) + trainer1 = get_trainer(save_folder=str(save_folder), save_filename=save_filename, fsdp_state_dict_type='full') trainer1.fit() state_dict_from_trainer1 = trainer1.state.state_dict() trainer1.close() load_path = str(save_folder / pathlib.Path('rank{rank}.pt')) - trainer2 = get_trainer(save_folder=str(save_folder), - save_filename=save_filename, - fsdp_state_dict_type='full', - load_path=load_path, - run_name=run_name, - autoresume=autoresume) + trainer2 = get_trainer(fsdp_state_dict_type='full', load_path=load_path) state_dict_from_trainer2 = trainer2.state.state_dict() if dist.get_global_rank() == 0: @@ -324,33 +305,19 @@ def test_fsdp_partitioned_state_dict_save(world_size, tmp_path: pathlib.Path, st @pytest.mark.gpu @world_size(2) @pytest.mark.parametrize('state_dict_type', ['local', 'sharded']) -@pytest.mark.parametrize('autoresume', [True, False]) @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'), reason='requires PyTorch 1.13 or higher') -def test_fsdp_partitioned_state_dict_load(world_size, tmp_path: pathlib.Path, state_dict_type: str, autoresume: bool): - if autoresume: - run_name = 'my-autoresume-run' - else: - run_name = None +def test_fsdp_partitioned_state_dict_load(world_size, tmp_path: pathlib.Path, state_dict_type: str): save_folder = tmp_path save_filename = 'rank{rank}.pt' trainer1 = get_trainer(save_folder=str(save_folder), save_filename=save_filename, - fsdp_state_dict_type=state_dict_type, - run_name=run_name, - autoresume=autoresume) + fsdp_state_dict_type=state_dict_type) trainer1.fit() state_dict_from_trainer1 = trainer1.state.state_dict() trainer1.close() load_path = str(save_folder / pathlib.Path('rank{rank}.pt')) - trainer2 = get_trainer( - save_folder=str(save_folder), - save_filename=save_filename, - fsdp_state_dict_type=state_dict_type, - load_path=load_path, - autoresume=autoresume, - run_name=run_name, - ) + trainer2 = get_trainer(fsdp_state_dict_type=state_dict_type, load_path=load_path) state_dict_from_trainer2 = trainer2.state.state_dict() # Compare saved state and loaded state for both ranks. diff --git a/tests/utils/test_autolog_hparams.py b/tests/utils/test_autolog_hparams.py index 73f54b330c7..667b6fa81ac 100644 --- a/tests/utils/test_autolog_hparams.py +++ b/tests/utils/test_autolog_hparams.py @@ -169,7 +169,7 @@ def test_extract_hparams_trainer(): 'deterministic_mode': False, # Distributed Training - 'dist_timeout': 600.0, + 'dist_timeout': 1800.0, 'ddp_sync_strategy': None, # Profiling
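
The hunks above keep exercising two ideas: checkpoint paths templated with {rank}/{local_rank}/{node_rank} placeholders, and the "local rank zero downloads while the other local ranks wait on a signal file before any collective" idiom that both the reverted and restored code paths use to avoid distributed timeouts on large downloads. Below is a minimal, self-contained sketch of those two ideas only. It is not Composer code: format_path_with_rank and fetch_on_local_rank_zero are hypothetical stand-ins for _format_path_with_current_rank and the get_file + signal-file logic in composer/utils/checkpoint.py, and the example path is made up.

import datetime
import os
import tempfile
import time

import torch.distributed as dist


def format_path_with_rank(path: str, rank: int, local_rank: int, node_rank: int) -> str:
    # str.format ignores unused keyword arguments, so paths without placeholders pass through unchanged.
    return path.format(rank=rank, local_rank=local_rank, node_rank=node_rank)


def fetch_on_local_rank_zero(formatted_path: str, destination: str, local_rank: int,
                             poll_interval: float = 0.1) -> None:
    # Local rank zero performs the (stubbed) download and drops a signal file; the other local ranks
    # busy-wait on that file instead of sitting in a collective, so a slow download cannot trip the
    # process-group timeout. Only after the file exists does anyone enter the cross-node barrier.
    signal_file = destination + '.local_rank0_completed'
    if local_rank == 0:
        with open(destination, 'wb') as f:  # stand-in for get_file / an object-store fetch
            f.write(formatted_path.encode('utf-8'))
        with open(signal_file, 'wb') as f:
            f.write(b'local_rank0_completed')
    else:
        while not os.path.exists(signal_file):
            time.sleep(poll_interval)
    dist.barrier()  # inter-node synchronization, reached only after the intra-node wait
    if local_rank == 0:
        os.remove(signal_file)


if __name__ == '__main__':
    # Single-process gloo group so the sketch runs standalone; a real run would use the launcher's env vars.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1, timeout=datetime.timedelta(seconds=300))

    path = format_path_with_rank('checkpoints/ep1-rank{rank}.pt', rank=0, local_rank=0, node_rank=0)
    print(path)  # -> checkpoints/ep1-rank0.pt

    with tempfile.TemporaryDirectory() as tmp:
        fetch_on_local_rank_zero(path, os.path.join(tmp, 'rank0_checkpoint'), local_rank=0)

    dist.destroy_process_group()

The busy-wait keeps slow object-store downloads out of any collective call, so the process-group timeout is only exposed to the spread between the fastest and slowest nodes reaching the final barrier, which is the rationale given in the comments removed and restored by this revert.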