From 926a83c2ff010866477e43648965f7ee2c267965 Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Wed, 10 Jan 2024 01:48:35 -0800
Subject: [PATCH 1/5] first try

---
 composer/trainer/trainer.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 0655be3004..6c82131ac9 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -69,6 +69,16 @@
 Scheduler = Union[ComposerScheduler, PyTorchScheduler]
 
 
+# Context manager that takes in a seed as an argument
+# It first saves the rng state, then sets the seed, then yields, then restores the rng state
+@contextlib.contextmanager
+def _seed_context(seed: int):
+    rng_state = reproducibility.get_rng_state()
+    reproducibility.seed_all(seed)
+    yield
+    reproducibility.load_rng_state(rng_state)
+
+
 def _raise_missing_argument_exception(arg_name: str):
     raise ValueError((f'{arg_name} is a required argument and must be specified when constructing the '
                       f'{Trainer.__name__} or when calling {Trainer.__name__}.{Trainer.fit.__name__}(). '
@@ -1293,7 +1303,8 @@ def __init__(
 
         # FSDP wrap if not using monolith checkpoint on rank 0 only
         if self.state.fsdp_config is not None and fsdp_auto_wrap and not self.state.load_fsdp_monolith_rank0_only:
-            prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching)
+            with _seed_context(self.state.rank_zero_seed):
+                prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching)
 
         # Configure Deepspeed
         if self.state.deepspeed_config is not None:
@@ -1440,7 +1451,8 @@ def __init__(
         # FSDP wrap if model is not yet wrapped and FSDP is enabled. This can happen if
         # load_fsdp_monolith_rank0_only=True but no checkpoint was loaded.
         if not self.state.fsdp_enabled and self.state.fsdp_config is not None and self.state.fsdp_auto_wrap and self.state.load_fsdp_monolith_rank0_only:
-            prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching)
+            with _seed_context(self.state.rank_zero_seed):
+                prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching)
 
         self.engine.run_event(Event.AFTER_LOAD)

From 32a7b75bd7399bd859c794b610e875eaddd573b4 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Wed, 10 Jan 2024 16:40:41 +0000
Subject: [PATCH 2/5] add context

---
 composer/core/state.py            |  7 ++++---
 composer/trainer/trainer.py       | 14 ++------------
 composer/utils/reproducibility.py | 11 +++++++++++
 3 files changed, 17 insertions(+), 15 deletions(-)

diff --git a/composer/core/state.py b/composer/core/state.py
index be92a799ce..e5a410a465 100644
--- a/composer/core/state.py
+++ b/composer/core/state.py
@@ -27,7 +27,7 @@
 from composer.core.serializable import Serializable
 from composer.core.time import Time, Timestamp, TimeUnit
 from composer.devices import Device
-from composer.utils import batch_get, batch_set, dist, ensure_tuple, get_composer_env_dict, is_model_deepspeed
+from composer.utils import batch_get, batch_set, dist, ensure_tuple, get_composer_env_dict, is_model_deepspeed, reproducibility
 from composer.utils.misc import using_torch_2
 
 if TYPE_CHECKING:
@@ -1264,8 +1264,9 @@ def load_model_and_optimizer_state(
             assert self.fsdp_config is not None
             log.info('Wrapping model with FSDP after loading model_state.')
             from composer.trainer.dist_strategy import prepare_fsdp_module
-            prepare_fsdp_module(self.model, self.optimizers, self.fsdp_config, self.precision, self.device,
-                                self.auto_microbatching)
+            with reproducibility.seed_context(self.state.rank_zero_seed):
+                prepare_fsdp_module(self.model, self.optimizers, self.fsdp_config, self.precision, self.device,
+                                    self.auto_microbatching)
             log.debug('Finished wrapping model with FSDP.')
 
         # Legacy optimizer state load must happen after FSDP monolith
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 6c82131ac9..8497e1f41b 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -69,16 +69,6 @@
 Scheduler = Union[ComposerScheduler, PyTorchScheduler]
 
 
-# Context manager that takes in a seed as an argument
-# It first saves the rng state, then sets the seed, then yields, then restores the rng state
-@contextlib.contextmanager
-def _seed_context(seed: int):
-    rng_state = reproducibility.get_rng_state()
-    reproducibility.seed_all(seed)
-    yield
-    reproducibility.load_rng_state(rng_state)
-
-
 def _raise_missing_argument_exception(arg_name: str):
     raise ValueError((f'{arg_name} is a required argument and must be specified when constructing the '
                       f'{Trainer.__name__} or when calling {Trainer.__name__}.{Trainer.fit.__name__}(). '
@@ -1303,7 +1293,7 @@ def __init__(
 
         # FSDP wrap if not using monolith checkpoint on rank 0 only
         if self.state.fsdp_config is not None and fsdp_auto_wrap and not self.state.load_fsdp_monolith_rank0_only:
-            with _seed_context(self.state.rank_zero_seed):
+            with reproducibility.seed_context(self.state.rank_zero_seed):
                 prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching)
 
         # Configure Deepspeed
@@ -1451,7 +1441,7 @@ def __init__(
         # FSDP wrap if model is not yet wrapped and FSDP is enabled. This can happen if
         # load_fsdp_monolith_rank0_only=True but no checkpoint was loaded.
         if not self.state.fsdp_enabled and self.state.fsdp_config is not None and self.state.fsdp_auto_wrap and self.state.load_fsdp_monolith_rank0_only:
-            with _seed_context(self.state.rank_zero_seed):
+            with reproducibility.seed_context(self.state.rank_zero_seed):
                 prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching)
 
         self.engine.run_event(Event.AFTER_LOAD)

diff --git a/composer/utils/reproducibility.py b/composer/utils/reproducibility.py
index 0895b530d9..bcf820170e 100644
--- a/composer/utils/reproducibility.py
+++ b/composer/utils/reproducibility.py
@@ -54,6 +54,7 @@
 import time
 import warnings
 from typing import Any, Dict, List
+from contextlib import contextmanager
 
 import numpy as np
 import torch
@@ -62,6 +63,7 @@
 from composer.utils import dist
 
 __all__ = [
+    'seed_context',
     'configure_deterministic_mode',
     'get_random_seed',
     'seed_all',
@@ -76,6 +78,15 @@
 MAX_SEED = 2**32 - 1
 
 
+@contextmanager
+def seed_context(seed: int):
+    # Context manager to store rng_state and reseed for duration of context
+    rng_state = get_rng_state()
+    seed_all(seed)
+    yield
+    load_rng_state(rng_state)
+
+
 def configure_deterministic_mode():
     """Configure PyTorch deterministic mode.

From d71381d6e3c0deca4850acce7f5ed49fd95d99cf Mon Sep 17 00:00:00 2001
From: Your Name
Date: Wed, 10 Jan 2024 16:43:11 +0000
Subject: [PATCH 3/5] lint

---
 composer/core/state.py            | 2 +-
 composer/utils/reproducibility.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/composer/core/state.py b/composer/core/state.py
index e5a410a465..bea2220196 100644
--- a/composer/core/state.py
+++ b/composer/core/state.py
@@ -1264,7 +1264,7 @@ def load_model_and_optimizer_state(
             assert self.fsdp_config is not None
             log.info('Wrapping model with FSDP after loading model_state.')
             from composer.trainer.dist_strategy import prepare_fsdp_module
-            with reproducibility.seed_context(self.state.rank_zero_seed):
+            with reproducibility.seed_context(self.rank_zero_seed):
                 prepare_fsdp_module(self.model, self.optimizers, self.fsdp_config, self.precision, self.device,
                                     self.auto_microbatching)
             log.debug('Finished wrapping model with FSDP.')
diff --git a/composer/utils/reproducibility.py b/composer/utils/reproducibility.py
index bcf820170e..bf708ea5f1 100644
--- a/composer/utils/reproducibility.py
+++ b/composer/utils/reproducibility.py
@@ -80,7 +80,7 @@
 
 @contextmanager
 def seed_context(seed: int):
-    # Context manager to store rng_state and reseed for duration of context
+    """Context manager to store rng_state and reseed for duration of context."""
     rng_state = get_rng_state()
     seed_all(seed)
     yield

From 485495a680e45de09c9e3c4ba2a7f0f8b335fc5b Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Wed, 10 Jan 2024 11:43:47 -0500
Subject: [PATCH 4/5] more lint

---
 composer/core/state.py            | 3 ++-
 composer/utils/reproducibility.py | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/composer/core/state.py b/composer/core/state.py
index bea2220196..6a8a2038b6 100644
--- a/composer/core/state.py
+++ b/composer/core/state.py
@@ -27,7 +27,8 @@
 from composer.core.serializable import Serializable
 from composer.core.time import Time, Timestamp, TimeUnit
 from composer.devices import Device
-from composer.utils import batch_get, batch_set, dist, ensure_tuple, get_composer_env_dict, is_model_deepspeed, reproducibility
+from composer.utils import (batch_get, batch_set, dist, ensure_tuple, get_composer_env_dict, is_model_deepspeed,
+                            reproducibility)
 from composer.utils.misc import using_torch_2
 
 if TYPE_CHECKING:
diff --git a/composer/utils/reproducibility.py b/composer/utils/reproducibility.py
index bf708ea5f1..8ae92a5a83 100644
--- a/composer/utils/reproducibility.py
+++ b/composer/utils/reproducibility.py
@@ -53,8 +53,8 @@
 import textwrap
 import time
 import warnings
-from typing import Any, Dict, List
 from contextlib import contextmanager
+from typing import Any, Dict, List
 
 import numpy as np
 import torch

From e24e5b0ba9c1ab9a56d37f31417827c6299cade6 Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Wed, 10 Jan 2024 12:55:26 -0500
Subject: [PATCH 5/5] remove comment

---
 composer/trainer/dist_strategy.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/composer/trainer/dist_strategy.py b/composer/trainer/dist_strategy.py
index 66e5f7a509..96d476a453 100644
--- a/composer/trainer/dist_strategy.py
+++ b/composer/trainer/dist_strategy.py
@@ -243,11 +243,6 @@ def prepare_fsdp_module(
                     'gpu and some ranks are on meta. Either keep all ranks on the same '
                     "device or set fsdp_config['sync_module_states'] = True. Otherwise, "
                     'some weights may be randomly initialized when loading a checkpoint.')
-    # Comment out while we debug deadlock
-    # if fsdp_config['sharding_strategy'] in ('HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'):
-    #     raise ValueError('HSDP (HYBRID_SHARD or _HYBRID_SHARD_ZERO2) requires '
-    #                      'fsdp_config["sync_module_states"] = True or different replicas will '
-    #                      'have different weights.')
 
     # Check if other ranks OOMed after forward/backward pass when using auto microbatching. This
     # may happen when close to memory limit or with uneven memory usage across ranks. Since we
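
Usage sketch of the reproducibility.seed_context helper added in PATCH 2/5 (illustrative
only, not part of the patch series; assumes a single-process run with a Composer install
that includes these changes, and the seed values 42 and 17 are arbitrary):

    import torch

    from composer.utils import reproducibility

    reproducibility.seed_all(42)            # the run's normal seed
    before = torch.get_rng_state()          # snapshot of the torch CPU RNG for comparison

    with reproducibility.seed_context(17):  # e.g. state.rank_zero_seed in the trainer
        shard_init = torch.randn(4)         # draws in here follow seed 17, so every rank
                                            # (and every run) produces the same values

    after = torch.get_rng_state()
    assert torch.equal(before, after)       # the pre-context RNG state was restored, so
                                            # the outer seed-42 stream continues undisturbed

This is the property the trainer relies on when it wraps prepare_fsdp_module in
seed_context(self.state.rank_zero_seed): FSDP parameter initialization sees the same seed
on every rank, while the per-rank RNG streams used by the rest of training are left
untouched.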