Revert "Checkpoints Simplified (#2041)" (#2056)
This reverts commit 25c9a67.
dakinggg authored Mar 10, 2023
1 parent 48ec0f5 commit 0ef85dc
Showing 5 changed files with 108 additions and 75 deletions.
4 changes: 2 additions & 2 deletions composer/trainer/trainer.py
@@ -803,7 +803,7 @@ class Trainer:
.. seealso:: :mod:`composer.utils.reproducibility` for more details on reproducibility.
dist_timeout (float, optional): Timeout, in seconds, for initializing the distributed process group.
(default: ``600.0``)
(default: ``1800.0``)
ddp_sync_strategy (str | DDPSyncStrategy, optional): The strategy to use for synchronizing gradients.
Leave unset to let the trainer auto-configure this. See :class:`.DDPSyncStrategy`
for more details.
@@ -907,7 +907,7 @@ def __init__(
deterministic_mode: bool = False,

# Distributed Training
dist_timeout: float = 600.0,
dist_timeout: float = 1800.0,
ddp_sync_strategy: Optional[Union[str, DDPSyncStrategy]] = None,

# Profiling
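For context on the restored default: a minimal sketch of overriding ``dist_timeout`` when constructing a Trainer. The tiny module, synthetic dataset, and ``ComposerClassifier`` wrapper below are illustrative stand-ins and are not part of this commit.

import torch
from torch.utils.data import DataLoader, TensorDataset
from composer import Trainer
from composer.models import ComposerClassifier

# Synthetic two-class classifier and data so the sketch is self-contained.
module = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(8, 2))
dataset = TensorDataset(torch.randn(16, 8), torch.randint(0, 2, (16,)))

trainer = Trainer(
    model=ComposerClassifier(module, num_classes=2),
    train_dataloader=DataLoader(dataset, batch_size=4),
    max_duration='1ep',
    dist_timeout=1800.0,  # restored default; only matters when a distributed process group is initialized
)
trainer.fit()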
165 changes: 99 additions & 66 deletions composer/utils/checkpoint.py
@@ -38,6 +38,15 @@
_DEEPSPEED_TAG = 'deepspeed' # always tag with the same, deterministic name. We'll rename the tarball to the appropriate name.


def _format_path_with_rank_zero(path: str) -> str:
"""Formats ``path`` with the rank zero values."""
return path.format(
rank=0,
local_rank=0,
node_rank=0,
)


def _format_path_with_current_rank(path: str) -> str:
"""Formats ``path`` formatted with the current rank values."""
return path.format(
@@ -184,27 +193,31 @@ def load_checkpoint(
"""
# download the checkpoint to the node-local folder
log.debug('Loading checkpoint at %s', path)
# Each node gets one unique folder to store checkpoints that is shared amongst all local ranks in that node
tempdir_ctx = tempfile.TemporaryDirectory() if dist.get_local_rank() == 0 else contextlib.nullcontext(None)
# Each node gets one unique folder to store checkpoints that is shared amongst all local ranks in that node.
# If fsdp sharded state_dicts is enabled then EVERY rank gets a unique checkpoint folder.
tempdir_ctx = (tempfile.TemporaryDirectory() if (state.fsdp_sharded_state_dict_enabled or
dist.get_local_rank() == 0) else contextlib.nullcontext(None))
with tempdir_ctx as tempdir:
try:
# Get path to temporary folder for node-local checkpoints created on rank 0
node_checkpoint_folder = _get_local_rank_zero_path(tempdir)
# Get the path to the proper checkpoint folder corresponding to the current rank's node.
# If fsdp_sharded_state_dict_enabled then just use that rank's unique tempdir.
node_checkpoint_folder = (tempdir if state.fsdp_sharded_state_dict_enabled else
_get_node_checkpoint_download_folder(tempdir))
assert node_checkpoint_folder is not None

composer_states_filepath, extracted_tar_checkpoint_folder = download_checkpoint(
composer_states_filepath, extracted_checkpoint_folder, extracted_rank_n = download_checkpoint(
path=path,
node_checkpoint_folder=node_checkpoint_folder,
object_store=object_store,
progress_bar=progress_bar,
fsdp_sharded_state_dict_enabled=state.fsdp_sharded_state_dict_enabled,
deepspeed_checkpoint=is_model_deepspeed(state.model),
)
rng_state_dicts = _restore_checkpoint(
state,
logger,
composer_states_filepath,
extracted_tar_checkpoint_folder,
extracted_rank_n,
extracted_checkpoint_folder,
load_weights_only=load_weights_only,
strict_model_weights=strict_model_weights,
ignore_keys=ignore_keys,
@@ -220,7 +233,7 @@ def load_checkpoint(
return rng_state_dicts


def _get_local_rank_zero_path(path: Optional[str]) -> str:
def _get_node_checkpoint_download_folder(path: Optional[str]) -> str:
"""Broadcasts the ``path`` from the LOCAL rank zero to all LOCAL ranks."""
local_rank_zero = dist.get_local_world_size() * dist.get_node_rank()
paths = dist.all_gather_object(path)
@@ -235,77 +248,88 @@ def download_checkpoint(
object_store: Optional[Union[ObjectStore, LoggerDestination]],
progress_bar: bool,
fsdp_sharded_state_dict_enabled: bool = False,
deepspeed_checkpoint: bool = False,
) -> Tuple[str, Optional[str]]:
) -> Tuple[str, Optional[str], bool]:
"""Download the checkpoint stored at ``path``, potentially in ``object_store``, to ``node_checkpoint_folder``.
Returns a tuple of (``composer_states_filepath``, ``extracted_checkpoint_folder``).
Returns a tuple of (``composer_states_filepath``, ``extracted_checkpoint_folder``, ``extracted_rank_n``).
* The ``composer_states_filepath`` is the path to the composer states, which can be passed into
:meth:`torch.load`.
* The ``extracted_tar_checkpoint_folder`` is the path to the checkpoint folder, which can be passed into
* The ``extracted_checkpoint_folder`` is the path to the checkpoint folder, which can be passed into
:meth:`deepspeed.DeepSpeedEngine.load_checkpoint`.
* The ``extracted_rank_n`` is a boolean flag indicating whether a tarball was extracted on global
rank greater than 0.
"""
log.debug('Downloading checkpoint to folder %s', node_checkpoint_folder)
# Files do not have extensions as it could be .tar, .pt, or something else
rank_zero_checkpoint_filepath = os.path.join(node_checkpoint_folder, 'rank0_checkpoint')
rank_n_checkpoint_filepath = os.path.join(node_checkpoint_folder, f'rank{dist.get_global_rank()}_checkpoint')
extracted_tar_checkpoint_folder = None

# DeepSpeed checkpoints must be tarballs
if deepspeed_checkpoint and not is_tar(path):
raise ValueError(f'Checkpoint at {path} is not a tarball, which is needed for DeepSpeed checkpoint')

# Determine where checkpoint will be after downloading
extracted_checkpoint_folder = None
extracted_rank_n = False
if is_tar(path):
# DeepSpeed checkpoints OR checkpoints of all formats with compression
extracted_tar_checkpoint_folder = os.path.join(node_checkpoint_folder, 'checkpoint')
composer_states_filepath = os.path.join(extracted_tar_checkpoint_folder, _COMPOSER_STATES_FILENAME)
elif fsdp_sharded_state_dict_enabled:
# FSDP sharded state dict has a Composer state dict per rank
composer_states_filepath = rank_n_checkpoint_filepath
extracted_checkpoint_folder = os.path.join(node_checkpoint_folder, 'checkpoint')
composer_states_filepath = os.path.join(extracted_checkpoint_folder, _COMPOSER_STATES_FILENAME)
else:
# Vanilla Composer has a single Composer state dict on local rank zero
composer_states_filepath = os.path.join(node_checkpoint_folder, 'rank0_checkpoint')

local_path = _format_path_with_current_rank(path)
local_rank_zero_path = _get_local_rank_zero_path(local_path)
# it's not an archive; it's just the composer state dict
# and only rank zero has this file unless fsdp_sharded_state_dict_enabled then
# every rank has its own file.
extracted_checkpoint_folder = None
composer_states_filepath = (rank_n_checkpoint_filepath
if fsdp_sharded_state_dict_enabled else rank_zero_checkpoint_filepath)

try:
# Download on local rank 0 or if path is different from local rank 0
if dist.get_local_rank() == 0 or local_path != local_rank_zero_path:
get_file_succeeded = True
if ((fsdp_sharded_state_dict_enabled and dist.get_global_rank() == 0) or
(not fsdp_sharded_state_dict_enabled and dist.get_local_rank() == 0)):
# every NODE needs the GLOBAL rank zero checkpoint unless fsdp_sharded_state_dict_enabled.
path = _format_path_with_rank_zero(path)
get_file(destination=rank_zero_checkpoint_filepath,
path=path,
object_store=object_store,
progress_bar=progress_bar)
if extracted_checkpoint_folder is not None:
try:
with tarfile.open(rank_zero_checkpoint_filepath) as tarball:
tarball.extractall(extracted_checkpoint_folder)
except FileNotFoundError:
# Not re-raising the file-not-found error as that is irrelevant;
# the underlying issue is that the checkpoint file does not exist on the disk
# or could not be downloaded
raise RuntimeError(f'Checkpoint {path} does not exist')

if rank_zero_checkpoint_filepath != rank_n_checkpoint_filepath:
# every RANK needs ITS OWN checkpoint.
# But, the global rank zero is a special case -- these files are the same!
assert dist.get_global_rank() != 0, 'invariant violation'

try:
get_file(
destination=rank_n_checkpoint_filepath,
path=_format_path_with_current_rank(path),
object_store=object_store,
progress_bar=progress_bar,
)
except FileNotFoundError as e:
get_file_succeeded = False
# If the checkpoint is not found, raise the error only on local rank zero, which
# requires a checkpoint, or if FSDP sharded checkpoints or deepspeed are enabled,
# which require a checkpoint on all ranks
if dist.get_local_rank() == 0 or fsdp_sharded_state_dict_enabled or deepspeed_checkpoint:
raise e
# Otherwise, ignore error as standard checkpointing does not have rank-local checkpoints
get_file(destination=rank_n_checkpoint_filepath,
path=_format_path_with_current_rank(path),
object_store=object_store,
progress_bar=progress_bar)
except FileNotFoundError:
# Allowing not-found errors to be ignored as sometimes there won't be rank-local checkpoints
# (e.g. when using neither deepspeed nor fsdp sharded checkpoints)
pass

# Extract tarballs, which happens for DeepSpeed or compression
if get_file_succeeded and extracted_tar_checkpoint_folder is not None:
with tarfile.open(rank_n_checkpoint_filepath) as tarball:
tarball.extractall(extracted_tar_checkpoint_folder)
if extracted_checkpoint_folder is not None:
try:
# it's an archive and needs to be extracted
with tarfile.open(rank_n_checkpoint_filepath) as tarball:
tarball.extractall(extracted_checkpoint_folder)
extracted_rank_n = True
except FileNotFoundError:
# this will happen most of the time (i.e. whenever deepspeed
# is not being used) so not logging anything
pass

finally:
# Wait for all checkpoints on the node to finish downloading. First, we busy wait until
# file exists, ensuring synchronization intra-node. Next, we use `dist.barrier()` to
# sync across nodes. This is necessary to avoid timing out on the barrier when downloading
# large checkpoints, which may exceed the normal timeout. Now, a timeout is only
# encountered if the difference between checkpoint download times on the slowest and
# fastest nodes exceeds the timeout.
dist.local_rank_zero_download_and_wait(composer_states_filepath)
# Wait for all checkpoints on the node to finish downloading
# Putting the barrier in a finally so the rank will always block on the barrier,
# even if it has an exception.
# Any exception will be re-raised after the barrier passes. The launcher script
# will detect the process crash and terminate the other ranks
dist.barrier()

return composer_states_filepath, extracted_tar_checkpoint_folder
return composer_states_filepath, extracted_checkpoint_folder, extracted_rank_n


def _flatten_keys(obj: Any, paths: List[str], existing_path: str):
@@ -399,14 +423,16 @@ def _restore_checkpoint(
state: State,
logger: Logger,
composer_states_filepath: str,
extracted_tar_checkpoint_folder: Optional[str],
extracted_rank_n: bool,
extracted_checkpoint_folder: Optional[str],
load_weights_only: bool,
strict_model_weights: bool,
ignore_keys: Optional[Union[List[str], Callable[[Dict], None]]],
exclude_algorithms: Optional[List[str]],
algorithm_passes: Optional[List[AlgorithmPass]],
) -> Optional[List[Dict[str, Any]]]:
"""Restore a checkpoint into ``state`` and returns the rng state dicts (if ``load_weights_only`` is False)."""
# Now, all ranks load the checkpoint that local rank zero downloaded
state_dict = safe_torch_load(composer_states_filepath)
if ignore_keys:
# Filter provided list of key paths
@@ -417,11 +443,15 @@
log.debug(f"Loaded checkpoint with keys {state_dict.keys()} and state keys {state_dict['state'].keys()}")

if is_model_deepspeed(state.model):
if extracted_tar_checkpoint_folder is None:
if extracted_checkpoint_folder is None:
raise RuntimeError('Deepspeed checkpoints require a tarball, not a weights file.')

global_rank = dist.get_global_rank()
if global_rank > 0 and not extracted_rank_n:
raise RuntimeError(f'Deepspeed checkpoint missing for rank {global_rank}')

load_path, _ = state.deepspeed_model.load_checkpoint(
extracted_tar_checkpoint_folder,
extracted_checkpoint_folder,
tag=_DEEPSPEED_TAG,
load_module_only=load_weights_only,
load_module_strict=strict_model_weights,
@@ -452,6 +482,7 @@ def save_checkpoint(
*,
weights_only: bool = False,
) -> Union[str, None]: # noqa: D103

log.debug('Saving checkpoint to %s', filename)

is_deepspeed = is_model_deepspeed(state.model)
Expand All @@ -468,15 +499,15 @@ def save_checkpoint(
if dirname:
os.makedirs(dirname, exist_ok=True)

# Only rank 0 saves the state_dict unless state.fsdp_sharded_state_dict_enabled=True
# only rank 0 saves the state_dict unless state.fsdp_sharded_state_dict_enabled=True.
if dist.get_global_rank() == 0 or state.fsdp_sharded_state_dict_enabled:
with open(save_filename, 'wb') as f:
torch.save(state_dict, f)

if is_tar(save_filename):
_compress_file(save_filename, basename=_COMPOSER_STATES_FILENAME)

# All ranks save for deepspeed
# all ranks save for deepspeed
if is_deepspeed:
_save_deepspeed_model(state.deepspeed_model, save_filename)

@@ -485,7 +516,9 @@
if dist.get_global_rank() == 0 or is_deepspeed or state.fsdp_sharded_state_dict_enabled:
assert os.path.exists(save_filename), 'Expected file to have been saved.'
return save_filename
return None
else:
# no file saved
return None


def _compress_file(filename: str, basename: str):
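As a reading aid (not part of the diff): a minimal sketch of the rank-placeholder convention that the restored ``_format_path_with_rank_zero`` and ``_format_path_with_current_rank`` helpers rely on. The path template below is hypothetical.

# Checkpoint path templates may contain {rank}, {local_rank}, and {node_rank} fields.
path_template = 'checkpoints/ep1-rank{rank}.pt'  # hypothetical template

# _format_path_with_rank_zero: every node resolves the GLOBAL rank zero file.
rank_zero_path = path_template.format(rank=0, local_rank=0, node_rank=0)

# _format_path_with_current_rank: each rank resolves its own file, e.g. global rank 3
# running as local rank 1 on node 1.
rank_n_path = path_template.format(rank=3, local_rank=1, node_rank=1)

print(rank_zero_path)  # checkpoints/ep1-rank0.pt
print(rank_n_path)     # checkpoints/ep1-rank3.pt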
4 changes: 2 additions & 2 deletions composer/utils/dist.py
@@ -351,7 +351,7 @@ def is_initialized():
return dist.is_initialized()


def initialize_dist(device: Union[str, Device], timeout: float = 600.0):
def initialize_dist(device: Union[str, Device], timeout: float = 300.0):
"""Initialize the default PyTorch distributed process group.
This function assumes that the following environment variables are set:
@@ -374,7 +374,7 @@ def initialize_dist(device: Union[str, Device], timeout: float = 600.0):
interpreted. Either a string corresponding to a device (one of ``'cpu'``,
``'gpu'``, ``'mps'``, or ``'tpu'``) or a :class:`.Device`.
timeout (float, optional): The timeout for operations executed against the process
group, expressed in seconds. (default: ``600.0``).
group, expressed in seconds. (default: ``300.0``).
"""
# If device is string, get corresponding composer.devices.Device object
device_obj = get_device(device)
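A minimal usage sketch for the restored 300-second default of ``initialize_dist``, assuming the launcher has already set the usual torch.distributed environment variables (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT, ...):

from composer.utils import dist

# Rely on the restored 300 s default for process-group initialization...
dist.initialize_dist('gpu')

# ...or pass a larger timeout explicitly for clusters with slow rendezvous
# (the process group may only be initialized once per process):
# dist.initialize_dist('gpu', timeout=1800.0)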
8 changes: 4 additions & 4 deletions composer/utils/inference.py
@@ -168,10 +168,10 @@ def export_for_inference(
# download checkpoint and load weights only
log.debug('Loading checkpoint at %s', load_path)
with tempfile.TemporaryDirectory() as tempdir:
composer_states_filepath, _ = download_checkpoint(path=load_path,
node_checkpoint_folder=tempdir,
object_store=load_object_store,
progress_bar=True)
composer_states_filepath, _, _ = download_checkpoint(path=load_path,
node_checkpoint_folder=tempdir,
object_store=load_object_store,
progress_bar=True)
state_dict = safe_torch_load(composer_states_filepath)
missing_keys, unexpected_keys = model.load_state_dict(state_dict['state']['model'], strict=load_strict)
if len(missing_keys) > 0:
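For reference, a minimal sketch of the weights-only load that ``export_for_inference`` performs above, using a stand-in ``nn.Module`` and a hypothetical local checkpoint path in place of the download step:

import torch
from torch import nn

model = nn.Linear(4, 2)  # stand-in module; assumes the checkpoint was saved from a matching model
state_dict = torch.load('checkpoints/ep1-rank0.pt', map_location='cpu')  # hypothetical path

# Composer checkpoints nest the model weights under state_dict['state']['model'].
missing_keys, unexpected_keys = model.load_state_dict(state_dict['state']['model'], strict=False)
print('missing:', missing_keys, 'unexpected:', unexpected_keys)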
2 changes: 1 addition & 1 deletion tests/utils/test_autolog_hparams.py
@@ -169,7 +169,7 @@ def test_extract_hparams_trainer():
'deterministic_mode': False,

# Distributed Training
'dist_timeout': 600.0,
'dist_timeout': 1800.0,
'ddp_sync_strategy': None,

# Profiling
