diff --git a/composer/core/state.py b/composer/core/state.py
index be92a799ce..6a8a2038b6 100644
--- a/composer/core/state.py
+++ b/composer/core/state.py
@@ -27,7 +27,8 @@
 from composer.core.serializable import Serializable
 from composer.core.time import Time, Timestamp, TimeUnit
 from composer.devices import Device
-from composer.utils import batch_get, batch_set, dist, ensure_tuple, get_composer_env_dict, is_model_deepspeed
+from composer.utils import (batch_get, batch_set, dist, ensure_tuple, get_composer_env_dict, is_model_deepspeed,
+                            reproducibility)
 from composer.utils.misc import using_torch_2
 
 if TYPE_CHECKING:
@@ -1264,8 +1265,9 @@ def load_model_and_optimizer_state(
             assert self.fsdp_config is not None
             log.info('Wrapping model with FSDP after loading model_state.')
             from composer.trainer.dist_strategy import prepare_fsdp_module
-            prepare_fsdp_module(self.model, self.optimizers, self.fsdp_config, self.precision, self.device,
-                                self.auto_microbatching)
+            with reproducibility.seed_context(self.rank_zero_seed):
+                prepare_fsdp_module(self.model, self.optimizers, self.fsdp_config, self.precision, self.device,
+                                    self.auto_microbatching)
             log.debug('Finished wrapping model with FSDP.')
 
             # Legacy optimizer state load must happen after FSDP monolith
diff --git a/composer/trainer/dist_strategy.py b/composer/trainer/dist_strategy.py
index 66e5f7a509..96d476a453 100644
--- a/composer/trainer/dist_strategy.py
+++ b/composer/trainer/dist_strategy.py
@@ -243,11 +243,6 @@ def prepare_fsdp_module(
                 'gpu and some ranks are on meta. Either keep all ranks on the same '
                 "device or set fsdp_config['sync_module_states'] = True. Otherwise, "
                 'some weights may be randomly initialized when loading a checkpoint.')
-    # Comment out while we debug deadlock
-    # if fsdp_config['sharding_strategy'] in ('HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'):
-    #     raise ValueError('HSDP (HYBRID_SHARD or _HYBRID_SHARD_ZERO2) requires '
-    #                      'fsdp_config["sync_module_states"] = True or different replicas will '
-    #                      'have different weights.')
 
     # Check if other ranks OOMed after forward/backward pass when using auto microbatching. This
     # may happen when close to memory limit or with uneven memory usage across ranks. Since we
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 0655be3004..8497e1f41b 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -1293,7 +1293,8 @@ def __init__(
 
         # FSDP wrap if not using monolith checkpoint on rank 0 only
         if self.state.fsdp_config is not None and fsdp_auto_wrap and not self.state.load_fsdp_monolith_rank0_only:
-            prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching)
+            with reproducibility.seed_context(self.state.rank_zero_seed):
+                prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching)
 
         # Configure Deepspeed
         if self.state.deepspeed_config is not None:
@@ -1440,7 +1441,8 @@ def __init__(
         # FSDP wrap if model is not yet wrapped and FSDP is enabled. This can happen if
         # load_fsdp_monolith_rank0_only=True but no checkpoint was loaded.
         if not self.state.fsdp_enabled and self.state.fsdp_config is not None and self.state.fsdp_auto_wrap and self.state.load_fsdp_monolith_rank0_only:
-            prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching)
+            with reproducibility.seed_context(self.state.rank_zero_seed):
+                prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching)
 
         self.engine.run_event(Event.AFTER_LOAD)
 
diff --git a/composer/utils/reproducibility.py b/composer/utils/reproducibility.py
index 0895b530d9..8ae92a5a83 100644
--- a/composer/utils/reproducibility.py
+++ b/composer/utils/reproducibility.py
@@ -53,6 +53,7 @@
 import textwrap
 import time
 import warnings
+from contextlib import contextmanager
 from typing import Any, Dict, List
 
 import numpy as np
@@ -62,6 +63,7 @@
 from composer.utils import dist
 
 __all__ = [
+    'seed_context',
    'configure_deterministic_mode',
    'get_random_seed',
    'seed_all',
@@ -76,6 +78,15 @@
 MAX_SEED = 2**32 - 1
 
 
+@contextmanager
+def seed_context(seed: int):
+    """Context manager to store rng_state and reseed for duration of context."""
+    rng_state = get_rng_state()
+    seed_all(seed)
+    yield
+    load_rng_state(rng_state)
+
+
 def configure_deterministic_mode():
     """Configure PyTorch deterministic mode.
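For reference, seed_context saves the current RNG state, reseeds all RNGs via seed_all, and restores the saved state on exit, so the deferred parameter initialization that runs while FSDP wraps the model draws from a fixed, rank-zero-seeded stream without disturbing the RNG state used elsewhere (e.g. by dataloaders). Below is a minimal, illustrative usage sketch, not part of the patch; it assumes Composer with this change applied, and the variable names and seed values are made up:

    import torch

    from composer.utils import reproducibility

    reproducibility.seed_all(42)  # normal training seed
    _ = torch.rand(1)             # advance the global RNG stream

    with reproducibility.seed_context(17):
        # RNG-dependent setup (e.g. module init during FSDP wrapping) always
        # sees the state produced by seed_all(17), regardless of prior draws.
        init_weights = torch.rand(4)

    # The pre-context RNG state is restored on exit, so this draw matches a
    # run that never entered the context.
    next_draw = torch.rand(1)

Passing state.rank_zero_seed (rather than a per-rank seed) presumably keeps any weights materialized during wrapping identical across ranks, which matters when they are initialized locally instead of being synced from rank 0.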