Fix seed for FSDP wrap #2833

Merged · 5 commits · Jan 10, 2024
composer/core/state.py (8 changes: 5 additions & 3 deletions)
@@ -27,7 +27,8 @@
 from composer.core.serializable import Serializable
 from composer.core.time import Time, Timestamp, TimeUnit
 from composer.devices import Device
-from composer.utils import batch_get, batch_set, dist, ensure_tuple, get_composer_env_dict, is_model_deepspeed
+from composer.utils import (batch_get, batch_set, dist, ensure_tuple, get_composer_env_dict, is_model_deepspeed,
+                            reproducibility)
 from composer.utils.misc import using_torch_2

 if TYPE_CHECKING:
@@ -1264,8 +1265,9 @@ def load_model_and_optimizer_state(
 assert self.fsdp_config is not None
 log.info('Wrapping model with FSDP after loading model_state.')
 from composer.trainer.dist_strategy import prepare_fsdp_module
-prepare_fsdp_module(self.model, self.optimizers, self.fsdp_config, self.precision, self.device,
-                    self.auto_microbatching)
+with reproducibility.seed_context(self.rank_zero_seed):
+    prepare_fsdp_module(self.model, self.optimizers, self.fsdp_config, self.precision, self.device,
+                        self.auto_microbatching)
 log.debug('Finished wrapping model with FSDP.')

 # Legacy optimizer state load must happen after FSDP monolith
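Context for the change: parameter initialization during FSDP wrapping is driven by the global RNG, so ranks (or runs) that reach the wrap point with different RNG histories can materialize different weights; reseeding with rank_zero_seed makes the wrap deterministic. A minimal plain-PyTorch sketch of that RNG dependence (illustrative only, not Composer's wrapping code):

import torch

# Same seed before construction -> identical initial weights.
torch.manual_seed(7)
a = torch.nn.Linear(4, 4)

torch.manual_seed(7)
b = torch.nn.Linear(4, 4)

assert torch.equal(a.weight, b.weight) and torch.equal(a.bias, b.bias)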
composer/trainer/dist_strategy.py (5 changes: 0 additions & 5 deletions)
@@ -243,11 +243,6 @@ def prepare_fsdp_module(
             'gpu and some ranks are on meta. Either keep all ranks on the same '
             "device or set fsdp_config['sync_module_states'] = True. Otherwise, "
             'some weights may be randomly initialized when loading a checkpoint.')
-# Comment out while we debug deadlock
-# if fsdp_config['sharding_strategy'] in ('HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'):
-#     raise ValueError('HSDP (HYBRID_SHARD or _HYBRID_SHARD_ZERO2) requires '
-#                      'fsdp_config["sync_module_states"] = True or different replicas will '
-#                      'have different weights.')

 # Check if other ranks OOMed after forward/backward pass when using auto microbatching. This
 # may happen when close to memory limit or with uneven memory usage across ranks. Since we
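The deleted block was already commented out while a deadlock was being debugged; it would have required sync_module_states for the HSDP sharding strategies so that replicas start from identical weights. A hedged sketch of an fsdp_config that follows that guidance (only the two key names come from the removed check; the values shown are illustrative):

# Illustrative config only; key names taken from the removed check.
fsdp_config = {
    'sharding_strategy': 'HYBRID_SHARD',
    'sync_module_states': True,  # broadcast rank-0 module states so replicas start identical
}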
composer/trainer/trainer.py (6 changes: 4 additions & 2 deletions)
@@ -1293,7 +1293,8 @@ def __init__(

 # FSDP wrap if not using monolith checkpoint on rank 0 only
 if self.state.fsdp_config is not None and fsdp_auto_wrap and not self.state.load_fsdp_monolith_rank0_only:
-    prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching)
+    with reproducibility.seed_context(self.state.rank_zero_seed):
+        prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching)

 # Configure Deepspeed
 if self.state.deepspeed_config is not None:
@@ -1440,7 +1441,8 @@ def __init__(
 # FSDP wrap if model is not yet wrapped and FSDP is enabled. This can happen if
 # load_fsdp_monolith_rank0_only=True but no checkpoint was loaded.
 if not self.state.fsdp_enabled and self.state.fsdp_config is not None and self.state.fsdp_auto_wrap and self.state.load_fsdp_monolith_rank0_only:
-    prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching)
+    with reproducibility.seed_context(self.state.rank_zero_seed):
+        prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching)

 self.engine.run_event(Event.AFTER_LOAD)

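With both call sites wrapped in seed_context(self.state.rank_zero_seed), whichever path performs the FSDP wrap reseeds identically before initialization. A toy sketch of that property using the new helper, assuming seed_all seeds the global torch RNG; build() is a hypothetical stand-in for the module materialization done by prepare_fsdp_module, and the seed value is arbitrary:

import torch
from composer.utils import reproducibility


def build():
    # Stand-in for parameter materialization during FSDP wrapping.
    return torch.nn.Linear(8, 8)


# Path 1: wrap during __init__ (monolith rank-0 loading disabled).
with reproducibility.seed_context(17):
    m1 = build()

# Path 2: wrap after AFTER_LOAD (monolith rank-0 loading enabled, no checkpoint loaded).
with reproducibility.seed_context(17):
    m2 = build()

assert torch.equal(m1.weight, m2.weight)  # both paths reseed identically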
composer/utils/reproducibility.py (11 changes: 11 additions & 0 deletions)
@@ -53,6 +53,7 @@
 import textwrap
 import time
 import warnings
+from contextlib import contextmanager
 from typing import Any, Dict, List

 import numpy as np
@@ -62,6 +63,7 @@
 from composer.utils import dist

 __all__ = [
+    'seed_context',
     'configure_deterministic_mode',
     'get_random_seed',
     'seed_all',
@@ -76,6 +78,15 @@
 MAX_SEED = 2**32 - 1


+@contextmanager
+def seed_context(seed: int):
+    """Context manager to store rng_state and reseed for duration of context."""
+    rng_state = get_rng_state()
+    seed_all(seed)
+    yield
+    load_rng_state(rng_state)
+
+
 def configure_deterministic_mode():
     """Configure PyTorch deterministic mode.

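A small usage sketch of the new seed_context helper: the block reseeds the global RNGs via seed_all, and on exit the RNG state captured beforehand is restored with load_rng_state, so surrounding code continues its original random stream (the seed values below are arbitrary):

import torch
from composer.utils import reproducibility

reproducibility.seed_all(123)
before = torch.randn(3)          # drawn from the seed-123 stream

with reproducibility.seed_context(42):
    inside = torch.randn(3)      # reproducible across runs: always the seed-42 draw

after = torch.randn(3)           # continues the seed-123 stream as if the block never ran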