[feat] FSDP: add auto_wrap_bn (#531)
* [feat] FSDP: add auto_wrap_bn

- add a utility function to handle the wrapping of BN

* changelog
min-xu-ai authored Mar 18, 2021
1 parent 2fc1f6d commit 8b59267
Showing 4 changed files with 60 additions and 15 deletions.
CHANGELOG.md: 2 additions & 1 deletion

@@ -8,7 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Experimental: Add spectrain support ([#372](https://github.com/facebookresearch/fairscale/issues/372))
- FSDP: enabling pytorch SyncBN (no asserting) ([#527](https://github.com/facebookresearch/fairscale/issues/527))
- FSDP: enabled pytorch SyncBN (no asserting) ([#527](https://github.com/facebookresearch/fairscale/issues/527))
- FSDP: added auto\_wrap\_bn utility function ([#531](https://github.com/facebookresearch/fairscale/pull/531))

### Fixed
- OSS: fix a compatibility problem with lightning wrt optimizer state dict ([#510](https://github.com/facebookresearch/fairscale/issues/510))
fairscale/nn/data_parallel/__init__.py: 1 addition & 1 deletion

@@ -3,5 +3,5 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from .fully_sharded_data_parallel import FullyShardedDataParallel, TrainingState
from .fully_sharded_data_parallel import FullyShardedDataParallel, TrainingState, auto_wrap_bn
from .sharded_ddp import ShardedDataParallel
fairscale/nn/data_parallel/fully_sharded_data_parallel.py: 44 additions & 0 deletions

@@ -20,6 +20,7 @@
import torch.nn.functional as F

from fairscale.nn.misc import FlattenParamsWrapper
from fairscale.nn.wrap import auto_wrap, default_auto_wrap_policy, enable_wrap
from fairscale.optim.utils import calc_grad_norm
from fairscale.utils.containers import apply_to_tensors
from fairscale.utils.parallel import chunk_and_pad, enable_pytorch_sync_bn, validate_process_group
@@ -1337,3 +1338,46 @@ def _pre_load_state_dict_hook(
    state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], prefix: str, *args: Any
) -> None:
    replace_by_prefix_(state_dict, prefix, prefix + "_fsdp_wrapped_module.")


########################################################################################
# Below are APIs used together with FSDP, but not directly part of FSDP.
########################################################################################


def auto_wrap_bn(module: nn.Module) -> nn.Module:
    """
    Auto-wrap all BatchNorm (BN) instances with a safer FSDP, especially when conversion
    to sync BN is used and the outer FSDP is flattening.

    We put each BN in its own full-precision, unflattened, single-GPU-group FSDP. Note that
    SyncBNs still have a group size == world_size. The input and output of BN are still FP16
    in mixed precision mode. See ``keep_batchnorm_fp32`` here: https://nvidia.github.io/apex/amp.html

    This needs to be done on each rank, just like models are wrapped by FSDP on each rank.

    Args:
        module (nn.Module):
            The model (or part of the model) in which BN is to be pre-wrapped.

    Returns:
        Processed module, where BNs are wrapped with a special FSDP instance.
    """

    def wrap_bn_only_policy(module: nn.Module, recurse: bool, unwrapped_params: int) -> bool:
        is_bn = isinstance(module, torch.nn.modules.batchnorm._BatchNorm)
        if recurse:
            return not isinstance(module, tuple(default_auto_wrap_policy.FORCE_LEAF_MODULES))  # type: ignore
        else:
            return is_bn and not isinstance(module, tuple(default_auto_wrap_policy.EXCLUDE_WRAP_MODULES))  # type: ignore

    my_rank = dist.get_rank()
    fsdp_config = {
        "wrapper_cls": FullyShardedDataParallel,
        "process_group": dist.new_group(ranks=[my_rank]),  # No sharding with this single-member group.
        "mixed_precision": False,  # Keep the weights in FP32.
        "flatten_parameters": False,  # Do not flatten.
    }

    with enable_wrap(wrap_bn_only_policy, **fsdp_config):
        return auto_wrap(module)
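For context (not part of this commit), a minimal usage sketch of the new helper: wrap the BNs first, optionally convert them to SyncBN, then wrap the whole model with a mixed-precision, flattening FSDP, mirroring the updated test below. The toy Sequential model and the outer FSDP arguments here are illustrative assumptions, and the snippet assumes the default torch.distributed process group has already been initialized.

# Minimal usage sketch (illustrative; assumes torch.distributed is already initialized).
import torch.nn as nn
from torch.nn import SyncBatchNorm

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import auto_wrap_bn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
model = auto_wrap_bn(model)  # BNs now sit in their own FP32, unflattened, single-rank FSDP.
model = SyncBatchNorm.convert_sync_batchnorm(model)  # Optional; order relative to auto_wrap_bn should not matter.
model = FSDP(model, mixed_precision=True, flatten_parameters=True).cuda()  # Outer FSDP can flatten and use FP16.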
tests/nn/data_parallel/test_fsdp_regnet.py: 13 additions & 13 deletions

@@ -9,17 +9,17 @@

""" Test FSDP with regnet-like model. """

import random
import tempfile

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn import BatchNorm2d, Conv2d, Module, SyncBatchNorm
from torch.optim import SGD

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import TrainingState
from fairscale.nn.data_parallel import TrainingState, auto_wrap_bn
from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, torch_version


@@ -35,16 +35,7 @@ def __init__(self):
    # TODO (Min): for now, we just test pytorch sync_bn here.
    # this will grow into regnet; testing apex sync_bn, etc.
    self.conv = Conv2d(2, 2, (1, 1))
    # Put BN in its own FP32, unflattened, single GPU group FSDP.
    # Note, SyncBNs still have a group size == world_size.
    # The input and output for BN are still FP16. See ``keep_batchnorm_fp32``
    # here: https://nvidia.github.io/apex/amp.html
    self.bn = FSDP(
        BatchNorm2d(2),
        mixed_precision=False,
        process_group=dist.new_group(ranks=[rank]),
        flatten_parameters=False,
    )
    self.bn = BatchNorm2d(2)

def forward(self, x):
    x = self.conv(x)
@@ -54,7 +45,16 @@ def forward(self, x):
# TODO (Min): check DDP equivalency.

model = Model()
model = SyncBatchNorm.convert_sync_batchnorm(model)
# Note: different ranks may wrap in a different order due to different random
# seeds, but the results should be the same.
if random.randint(0, 1) == 0:
    print("auto_wrap_bn, then convert_sync_batchnorm")
    model = auto_wrap_bn(model)
    model = SyncBatchNorm.convert_sync_batchnorm(model)
else:
    print("convert_sync_batchnorm, then auto_wrap_bn")
    model = SyncBatchNorm.convert_sync_batchnorm(model)
    model = auto_wrap_bn(model)
model = FSDP(model, **fsdp_config).cuda()
optim = SGD(model.parameters(), lr=0.1)
