From 8b59267b2e3f213b62b3aaa19e6dace0d5f10a26 Mon Sep 17 00:00:00 2001
From: Min Xu <24926999+min-xu-ai@users.noreply.github.com>
Date: Thu, 18 Mar 2021 09:15:37 -0700
Subject: [PATCH] [feat] FSDP: add auto_wrap_bn (#531)

* [feat] FSDP: add auto_wrap_bn

- add a utility function to handle wrapping of BN

* changelog
---
 CHANGELOG.md                                |  3 +-
 fairscale/nn/data_parallel/__init__.py      |  2 +-
 .../fully_sharded_data_parallel.py          | 44 +++++++++++++++++++
 tests/nn/data_parallel/test_fsdp_regnet.py  | 26 +++++------
 4 files changed, 60 insertions(+), 15 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ef77b3359..0167280d8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,7 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 - Experimental: Add spectrain support ([#372](https://github.com/facebookresearch/fairscale/issues/372))
-- FSDP: enabling pytorch SyncBN (no asserting) ([#527](https://github.com/facebookresearch/fairscale/issues/527))
+- FSDP: enabled pytorch SyncBN (no asserting) ([#527](https://github.com/facebookresearch/fairscale/issues/527))
+- FSDP: added auto\_wrap\_bn utility function ([#531](https://github.com/facebookresearch/fairscale/pull/531))
 
 ### Fixed
 - OSS: fix a compatibily problem with lightning wrt optimizer state dict ([#510](https://github.com/facebookresearch/fairscale/issues/510))

diff --git a/fairscale/nn/data_parallel/__init__.py b/fairscale/nn/data_parallel/__init__.py
index 8f43fea7f..ab812834a 100644
--- a/fairscale/nn/data_parallel/__init__.py
+++ b/fairscale/nn/data_parallel/__init__.py
@@ -3,5 +3,5 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
-from .fully_sharded_data_parallel import FullyShardedDataParallel, TrainingState
+from .fully_sharded_data_parallel import FullyShardedDataParallel, TrainingState, auto_wrap_bn
 from .sharded_ddp import ShardedDataParallel

diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index 3fee1eebd..b08bc5df8 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -20,6 +20,7 @@ import torch.nn.functional as F
 
 from fairscale.nn.misc import FlattenParamsWrapper
+from fairscale.nn.wrap import auto_wrap, default_auto_wrap_policy, enable_wrap
 from fairscale.optim.utils import calc_grad_norm
 from fairscale.utils.containers import apply_to_tensors
 from fairscale.utils.parallel import chunk_and_pad, enable_pytorch_sync_bn, validate_process_group
@@ -1337,3 +1338,46 @@ def _pre_load_state_dict_hook(
     state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], prefix: str, *args: Any
 ) -> None:
     replace_by_prefix_(state_dict, prefix, prefix + "_fsdp_wrapped_module.")
+
+
+########################################################################################
+# Below are APIs used together with FSDP, but not directly part of FSDP.
+########################################################################################
+
+
+def auto_wrap_bn(module: nn.Module) -> nn.Module:
+    """
+    Auto-wrap all BatchNorm (BN) instances with a safer FSDP, especially when conversion
+    to SyncBN is used and the outer FSDP is flattening.
+
+    We put each BN in its own full-precision, unflattened, single-GPU-group FSDP. Note that
+    SyncBNs still have a group size == world_size. The input and output for BN are still FP16
+    in mixed precision mode.
+    See ``keep_batchnorm_fp32`` here: https://nvidia.github.io/apex/amp.html
+
+    This needs to be done at each rank, just as the model itself is wrapped by FSDP at each rank.
+
+    Args:
+        module (nn.Module):
+            The model (or part of the model) whose BN instances are to be pre-wrapped.
+
+    Returns:
+        The processed module, in which BNs are wrapped with a special FSDP instance.
+    """
+
+    def wrap_bn_only_policy(module: nn.Module, recurse: bool, unwrapped_params: int) -> bool:
+        is_bn = isinstance(module, torch.nn.modules.batchnorm._BatchNorm)
+        if recurse:
+            return not isinstance(module, tuple(default_auto_wrap_policy.FORCE_LEAF_MODULES))  # type: ignore
+        else:
+            return is_bn and not isinstance(module, tuple(default_auto_wrap_policy.EXCLUDE_WRAP_MODULES))  # type: ignore
+
+    my_rank = dist.get_rank()
+    fsdp_config = {
+        "wrapper_cls": FullyShardedDataParallel,
+        "process_group": dist.new_group(ranks=[my_rank]),  # No sharding with this single-member group.
+        "mixed_precision": False,  # Keep the weights in FP32.
+        "flatten_parameters": False,  # Do not flatten.
+    }
+
+    with enable_wrap(wrap_bn_only_policy, **fsdp_config):
+        return auto_wrap(module)

diff --git a/tests/nn/data_parallel/test_fsdp_regnet.py b/tests/nn/data_parallel/test_fsdp_regnet.py
index e04c96620..4109e6066 100644
--- a/tests/nn/data_parallel/test_fsdp_regnet.py
+++ b/tests/nn/data_parallel/test_fsdp_regnet.py
@@ -9,17 +9,17 @@
 """ Test FSDP with regnet-like model. """
 
+import random
 import tempfile
 
 import pytest
 import torch
-import torch.distributed as dist
 import torch.multiprocessing as mp
 from torch.nn import BatchNorm2d, Conv2d, Module, SyncBatchNorm
 from torch.optim import SGD
 
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
-from fairscale.nn.data_parallel import TrainingState
+from fairscale.nn.data_parallel import TrainingState, auto_wrap_bn
 from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, torch_version
 
 
@@ -35,16 +35,7 @@ def __init__(self):
         # TODO (Min): for now, we just test pytorch sync_bn here.
         #             this will grow into regnet; testing apex sync_bn, etc.
         self.conv = Conv2d(2, 2, (1, 1))
-        # Put BN in is own FP32, unflatten, single GPU group FSDP.
-        # Note, SyncBNs still have a group size == world_size.
-        # The input and output for BN are still FP16. See ``keep_batchnorm_fp32``
-        # here: https://nvidia.github.io/apex/amp.html
-        self.bn = FSDP(
-            BatchNorm2d(2),
-            mixed_precision=False,
-            process_group=dist.new_group(ranks=[rank]),
-            flatten_parameters=False,
-        )
+        self.bn = BatchNorm2d(2)
 
     def forward(self, x):
         x = self.conv(x)
@@ -54,7 +45,16 @@ def forward(self, x):
 
     # TODO (Min): check DDP equivalency.
 
     model = Model()
-    model = SyncBatchNorm.convert_sync_batchnorm(model)
+    # Note: different ranks may wrap in a different order due to different random
+    # seeds, but the results should be the same.
+    if random.randint(0, 1) == 0:
+        print("auto_wrap_bn, then convert_sync_batchnorm")
+        model = auto_wrap_bn(model)
+        model = SyncBatchNorm.convert_sync_batchnorm(model)
+    else:
+        print("convert_sync_batchnorm, then auto_wrap_bn")
+        model = SyncBatchNorm.convert_sync_batchnorm(model)
+        model = auto_wrap_bn(model)
     model = FSDP(model, **fsdp_config).cuda()
     optim = SGD(model.parameters(), lr=0.1)
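
Usage sketch (hypothetical, not taken from the patch): the snippet below illustrates how auto_wrap_bn is meant to compose with SyncBatchNorm conversion and an outer FSDP wrapper, mirroring the test above. It assumes torch.distributed is already initialized on each rank with a CUDA device selected; the model definition and the outer FSDP arguments are placeholders, not part of this PR.

# Hypothetical usage sketch of auto_wrap_bn; assumes torch.distributed is already
# initialized on each rank and a CUDA device has been selected. The model below is
# a placeholder.
import torch.nn as nn
from torch.nn import SyncBatchNorm

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import auto_wrap_bn


def build_model() -> nn.Module:
    # Any module tree containing BatchNorm layers works.
    return nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8), nn.ReLU())


model = build_model()
# Pre-wrap each BN in its own FP32, unflattened, single-rank FSDP instance; the
# conversion to SyncBatchNorm can happen before or after this call (the test above
# exercises both orders).
model = auto_wrap_bn(model)
model = SyncBatchNorm.convert_sync_batchnorm(model)
# The outer FSDP can then use mixed precision and flattening, because the BN
# submodules are already wrapped with their own non-flattened FP32 FSDP.
model = FSDP(model, mixed_precision=True, flatten_parameters=True).cuda()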