diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index 16862cd9c..903414169 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1517,7 +1517,7 @@ def _pre_load_state_dict_hook(
 ########################################################################################
 
 
-def auto_wrap_bn(module: nn.Module) -> nn.Module:
+def auto_wrap_bn(module: nn.Module, single_rank_pg: bool = False) -> nn.Module:
     """
     Auto wrap all BatchNorm (BN) instances with a safer FSDP, esp. when convert
     to sync BN is used and the outer FSDP is flattening.
@@ -1531,6 +1531,9 @@ def auto_wrap_bn(module: nn.Module) -> nn.Module:
     Args:
         module (nn.Module):
             The model (or part of the model) in which BN to be pre-wrapped.
+        single_rank_pg (bool):
+            If True, put BNs in a single-rank process group (i.e. no sharding). Default: False.
+            This may be needed for Apex sync BN support; it is still under construction.
 
     Returns:
         Processed module, where BNs are wrapped with a special FSDP instance.
@@ -1543,10 +1546,15 @@ def wrap_bn_only_policy(module: nn.Module, recurse: bool, unwrapped_params: int)
         else:
             return is_bn and not isinstance(module, tuple(default_auto_wrap_policy.EXCLUDE_WRAP_MODULES))  # type: ignore
 
-    my_rank = dist.get_rank()
+    pg = None
+    if single_rank_pg:
+        # No sharding with this single member group.
+        my_rank = dist.get_rank()
+        pg = dist.new_group(ranks=[my_rank])
+
     fsdp_config = {
         "wrapper_cls": FullyShardedDataParallel,
-        "process_group": dist.new_group(ranks=[my_rank]),  # No sharding with this single member group.
+        "process_group": pg,
         "mixed_precision": False,  # Keep the weights in FP32.
         "flatten_parameters": False,  # Do not flatten.
     }
diff --git a/fairscale/utils/testing.py b/fairscale/utils/testing.py
index 2cec5821f..d83570603 100644
--- a/fairscale/utils/testing.py
+++ b/fairscale/utils/testing.py
@@ -32,6 +32,7 @@
 import multiprocessing
 import os
 import random
+import subprocess
 import sys
 import tempfile
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
@@ -116,6 +117,30 @@ def torch_version() -> Tuple[int, ...]:
     return tuple(int(n) for n in numbering)
 
 
+_smi_ver = None
+
+
+def torch_cuda_version(compiled: bool = False) -> Tuple[int, ...]:
+    if compiled:
+        numbering = torch.version.cuda.split(".")[:2]  # CUDA version torch was built with
+    else:
+        global _smi_ver
+        if _smi_ver is None:
+
+            def get_smi_ver() -> str:
+                """Get CUDA version from nvidia-smi"""
+                for line in subprocess.check_output("nvidia-smi".split()).decode("utf-8").split("\n"):
+                    if "CUDA Version" in line:
+                        res = line.split()[8]
+                        assert res.startswith("10.") or res.startswith("11."), res
+                        return res
+                assert False, "CUDA version not found in nvidia-smi output"
+
+            _smi_ver = get_smi_ver()
+        numbering = _smi_ver.split(".")[:2]  # runtime CUDA version from nvidia-smi
+    return tuple(int(n) for n in numbering)
+
+
 def dist_init(rank: int, world_size: int, filename: str, filename_rpc: str = "") -> bool:
     """
     Initialize torch distributed, based on a temporary file shared across ranks, which makes it possible for unrelated
@@ -445,7 +470,7 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False, dict_key: O
                 # Add dict key to the assertion error.
                 msg = e.args[0]
                 new_msg = f"For dict key '{dict_key}': {msg}"
-                raise AssertionError(new_msg)
+                raise AssertionError(new_msg) from None
             else:
                 raise e
     else:
diff --git a/tests/ci_test_list_2.txt b/tests/ci_test_list_2.txt
index 7aaf33400..024015bd9 100644
--- a/tests/ci_test_list_2.txt
+++ b/tests/ci_test_list_2.txt
@@ -33,4 +33,8 @@ tests/nn/pipe/test_deferred_batch_norm.py
 tests/nn/pipe/test_dependency.py
 tests/nn/pipe/test_stream.py
 tests/experimental/nn/test_multiprocess_pipe.py
+tests/nn/moe/test_moe_layer.py
+tests/nn/moe/test_top2gating.py
+tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py
+tests/experimental/nn/test_offload.py
 tests/nn/data_parallel/test_fsdp_apply.py
diff --git a/tests/ci_test_list_3.txt b/tests/ci_test_list_3.txt
index 3010b6b30..b4f6367d8 100644
--- a/tests/ci_test_list_3.txt
+++ b/tests/ci_test_list_3.txt
@@ -1,10 +1,10 @@
+tests/nn/data_parallel/test_fsdp_regnet.py
 tests/nn/data_parallel/test_fsdp_uneven.py
 tests/nn/data_parallel/test_fsdp_grad_scaler.py
 tests/nn/data_parallel/test_fsdp_no_sync.py
 tests/nn/data_parallel/test_fsdp_summon_full_params.py
 tests/nn/data_parallel/test_fsdp_input.py
 tests/nn/data_parallel/test_fsdp_multiple_forward.py
-tests/nn/data_parallel/test_fsdp_regnet.py
 tests/nn/data_parallel/test_fsdp_optimizer_utils.py
 tests/nn/data_parallel/test_sharded_ddp_features.py
 tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
@@ -17,12 +17,8 @@ tests/nn/pipe/skip/test_portal.py
 tests/nn/pipe/skip/test_tracker.py
 tests/nn/pipe/skip/test_inspect_skip_layout.py
 tests/nn/pipe/test_checkpoint_ddp.py
-tests/nn/moe/test_moe_layer.py
-tests/nn/moe/test_top2gating.py
 tests/optim/test_single_node_adascale.py
 tests/optim/test_adam.py
 tests/optim/test_oss.py
 tests/optim/test_oss_adascale.py
 tests/optim/test_ddp_adascale.py
-tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py
-tests/experimental/nn/test_offload.py
diff --git a/tests/nn/data_parallel/test_fsdp_regnet.py b/tests/nn/data_parallel/test_fsdp_regnet.py
index 31a8d1d5f..8f600cd7d 100644
--- a/tests/nn/data_parallel/test_fsdp_regnet.py
+++ b/tests/nn/data_parallel/test_fsdp_regnet.py
@@ -15,13 +15,26 @@
 
 import pytest
 import torch
+from torch.cuda.amp import GradScaler
 import torch.multiprocessing as mp
-from torch.nn import BatchNorm2d, Conv2d, Module, SyncBatchNorm
+from torch.nn import (
+    AdaptiveAvgPool2d,
+    BatchNorm2d,
+    Conv2d,
+    CrossEntropyLoss,
+    Linear,
+    Module,
+    ReLU,
+    Sequential,
+    Sigmoid,
+    SyncBatchNorm,
+)
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import SGD
 
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState, auto_wrap_bn
+from fairscale.optim.grad_scaler import ShardedGradScaler
 from fairscale.utils.testing import (
     dist_init,
     objects_are_equal,
@@ -29,21 +42,96 @@
     skip_if_single_gpu,
     state_dict_norm,
     teardown,
+    torch_cuda_version,
     torch_version,
 )
 
+# Const test params.
+# Reduce iterations to 1 for debugging.
+# Change world_size to 8 on beefy machines for better test coverage.
+_world_size = 2
+_iterations = 5
+
+# Cover different ReLU flavors. This will cause the DDP and FSDP models to have
+# different ReLUs, since they will get different random flags.
+_relu_inplace = True
+if random.randint(0, 1) == 0:
+    _relu_inplace = False
+
+# TODO (Min): test apex BN when available in the future.
+try:
+    import apex
+
+    apex_bn_converter = apex.parallel.convert_syncbn_model
+except ImportError:
+    apex_bn_converter = None
+pytorch_bn_converter = SyncBatchNorm.convert_sync_batchnorm  # type: ignore
+_bn_converter = pytorch_bn_converter
+_single_rank_pg = False
+
+
+class ResBlock(Module):
+    """Conv block in regnet with residual connection."""
+
+    def __init__(self, width_in, width_out):
+        super().__init__()
+        self.proj = Conv2d(width_in, width_out, (1, 1), (2, 2), bias=False)
+        self.bn = BatchNorm2d(width_out)
+        self.f = Sequential(
+            Sequential(  # block a
+                Conv2d(width_in, width_out, (1, 1), (1, 1), bias=False), BatchNorm2d(width_out), ReLU(_relu_inplace),
+            ),
+            Sequential(  # block b
+                Conv2d(width_out, width_out, (3, 3), (2, 2), (1, 1), groups=2, bias=False),
+                BatchNorm2d(width_out),
+                ReLU(_relu_inplace),
+            ),
+            Sequential(  # block se
+                AdaptiveAvgPool2d((1, 1)),
+                Sequential(
+                    Conv2d(width_out, 2, (1, 1), (1, 1), bias=False),
+                    ReLU(_relu_inplace),
+                    Conv2d(2, width_out, (1, 1), (1, 1), bias=False),
+                    Sigmoid(),
+                ),
+            ),
+            Conv2d(width_out, width_out, (1, 1), (1, 1), bias=False),  # block c
+            BatchNorm2d(width_out),  # final_bn
+        )
+        self.relu = ReLU()
+        self.need_fsdp_wrap = True
+
+    def forward(self, x):
+        x = self.bn(self.proj(x)) + self.f(x)
+        return self.relu(x)
+
 
 class Model(Module):
+    """SSL model with trunk and head."""
+
     def __init__(self):
         super().__init__()
-        # TODO (Min): for now, we just test pytorch sync_bn here.
-        #     this will grow into regnet; testing apex sync_bn, etc.
-        self.conv = Conv2d(2, 2, (1, 1))
-        self.bn = BatchNorm2d(2)
+        print(f"Using relu inplace: {_relu_inplace}")
+
+        self.trunk = Sequential()
+        self.trunk.need_fsdp_wrap = True  # Set a flag for later wrapping.
+        stem = Sequential(Conv2d(2, 4, (3, 3), (2, 2), (1, 1), bias=False), BatchNorm2d(4), ReLU(_relu_inplace))
+        any_stage_block1_0 = ResBlock(4, 8)
+        self.trunk.add_module("stem", stem)
+        self.trunk.add_module("any_stage_block1", Sequential(any_stage_block1_0))
+
+        self.head = Sequential(
+            # TODO (Min): FSDP-mixed_precision doesn't compute the same way as DDP AMP when bias=True,
+            #     so we use bias=False for now in the projection_head.
+            #     The Conv2d layers above do not use bias in regnet, but even if they use
+            #     bias, FSDP and DDP seem to agree on how it is computed.
+            Sequential(Linear(16, 16, bias=False), ReLU(), Linear(16, 8, bias=False),),  # projection_head
+            Linear(8, 15, bias=False),  # prototypes0
+        )
 
     def forward(self, x):
-        x = self.conv(x)
-        x = self.bn(x)
+        x = self.trunk(x).reshape(-1)
+        x = self.head(x)
         return x
@@ -67,9 +155,10 @@ def ddp_ref():
     state_before = model.state_dict()
 
     # Get reference inputs per rank.
-    world_size = 2
-    iterations = 100
-    inputs = [[]] * world_size
+    world_size = _world_size
+    iterations = _iterations
+    print(f"Getting DDP reference for world_size {world_size} and iterations {iterations}")
+    inputs = [[] for i in range(world_size)]
     for rank in range(world_size):
         for i in range(iterations):
             inputs[rank].append(torch.rand(2, 2, 2, 2))
@@ -86,6 +175,7 @@
         args=(
             world_size,
             fsdp_config,
+            None,
             precision == "mixed",
             temp_file_name,
             unused,
@@ -128,6 +218,7 @@
     rank,
     world_size,
     fsdp_config,
+    fsdp_wrap_bn,
     ddp_mixed_precision,
     tempfile_name,
     unused,
@@ -143,27 +234,51 @@
     if fsdp_config:
         ddp = False
         assert isinstance(fsdp_config, dict), str(fsdp_config)
+        if fsdp_config["mixed_precision"]:
+            # To match DDP in AMP -O1, we need fp32 reduce scatter.
+            fsdp_config["fp32_reduce_scatter"] = True
 
     model = Model()
     model.load_state_dict(state_before)
     model = model.cuda()
 
+    class DummyScaler:
+        def scale(self, loss):
+            return loss
+
+        def step(self, optim):
+            optim.step()
+
+        def update(self):
+            pass
+
+    scaler = DummyScaler()  # No-op scaler; replaced below when AMP is enabled.
     if ddp:
         model = SyncBatchNorm.convert_sync_batchnorm(model)
-        model = DDP(model, device_ids=[rank])
+        model = DDP(model, device_ids=[rank], broadcast_buffers=True)
+        if ddp_mixed_precision:
+            scaler = GradScaler()
     else:
         # Note, different rank may wrap in different order due to different random
         # seeds. But results should be the same.
         if random.randint(0, 1) == 0:
-            print("auto_wrap_bn, then convert_sync_batchnorm")
-            model = auto_wrap_bn(model)
-            model = SyncBatchNorm.convert_sync_batchnorm(model)
+            print(f"auto_wrap_bn {fsdp_wrap_bn}, then convert_sync_batchnorm")
+            if fsdp_wrap_bn:
+                model = auto_wrap_bn(model, _single_rank_pg)
+            model = _bn_converter(model)
         else:
-            print("convert_sync_batchnorm, then auto_wrap_bn")
-            model = SyncBatchNorm.convert_sync_batchnorm(model)
-            model = auto_wrap_bn(model)
+            print(f"convert_sync_batchnorm, then auto_wrap_bn {fsdp_wrap_bn}")
+            model = _bn_converter(model)
+            if fsdp_wrap_bn:
+                model = auto_wrap_bn(model, _single_rank_pg)
         model = FSDP(model, **fsdp_config).cuda()
+        if fsdp_config["mixed_precision"]:
+            scaler = ShardedGradScaler()
+    # Print the model for verification.
+    if rank == 0:
+        print(model)
     optim = SGD(model.parameters(), lr=0.1)
+    loss_func = CrossEntropyLoss()
 
     for in_data in inputs[rank]:
         in_data = in_data.cuda()
@@ -171,11 +286,15 @@
         if ddp and ddp_mixed_precision:
             in_data = in_data.half()
             context = torch.cuda.amp.autocast(enabled=True)
+        if not ddp and fsdp_config["mixed_precision"]:
+            context = torch.cuda.amp.autocast(enabled=True)
         with context:
             out = model(in_data)
-            loss = out.sum()
-        loss.backward()
-        optim.step()
+            fake_label = torch.zeros(1, dtype=torch.long).cuda()
+            loss = loss_func(out.unsqueeze(0), fake_label)
+        scaler.scale(loss).backward()
+        scaler.step(optim)
+        scaler.update()
         optim.zero_grad()
 
     if ddp:
@@ -190,6 +309,15 @@
     # Move tensors to CPU to compare numerics.
     for k, v in fsdp_state.items():
         fsdp_state[k] = v.cpu()
+    # Change False to True to enable this when you want to debug a mismatch.
+    if False and rank == 0:
+
+        def dump(d):
+            for k, v in d.items():
+                print(k, v)
+
+        dump(state_after)
+        dump(fsdp_state)
 
     assert objects_are_equal(state_after, fsdp_state, raise_exception=True)
     teardown()
@@ -215,10 +343,32 @@ def test1(temp_files, ddp_ref, precision, flatten):
     fsdp_config["mixed_precision"] = precision == "mixed"
     fsdp_config["flatten_parameters"] = flatten == "flatten"
 
-    world_size = 2
+    if fsdp_config["mixed_precision"] and torch_cuda_version() < (11, 0):
+        pytest.skip("Only CUDA 11 is supported with AMP equivalency")
+
+    # Wrap BN half of the time in full precision mode.
+    wrap_bn = True
+    if random.randint(0, 1) == 0:
+        wrap_bn = False
+    # Always wrap BN in mixed precision mode.
+    if fsdp_config["mixed_precision"]:
+        wrap_bn = True
+
+    world_size = _world_size
     mp.spawn(
         _test_func,
-        args=(world_size, fsdp_config, None, temp_files[0], temp_files[1], state_before, inputs, None, state_after),
+        args=(
+            world_size,
+            fsdp_config,
+            wrap_bn,
+            None,
+            temp_files[0],
+            temp_files[1],
+            state_before,
+            inputs,
+            None,
+            state_after,
+        ),
         nprocs=world_size,
         join=True,
    )
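
For reference, a minimal usage sketch of the new single_rank_pg flag (not part of the patch). It assumes torch.distributed is already initialized; the toy model is illustrative only:

    import torch.nn as nn

    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
    from fairscale.nn.data_parallel import auto_wrap_bn

    # Toy model; any module tree containing BatchNorm layers works the same way.
    model = nn.Sequential(nn.Conv2d(3, 8, (3, 3)), nn.BatchNorm2d(8), nn.ReLU())

    # Pre-wrap each BN in its own FSDP instance, kept in FP32 and unflattened.
    # With single_rank_pg=True, each BN would also get a single-rank process
    # group, i.e. its parameters would not be sharded across ranks.
    model = auto_wrap_bn(model, single_rank_pg=False)
    model = FSDP(model, mixed_precision=True, flatten_parameters=True).cuda()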
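
The torch_cuda_version helper distinguishes the CUDA toolkit PyTorch was compiled against from the runtime version reported by nvidia-smi; the two can differ when the driver is newer than the toolkit. A sketch of the gating pattern used in test1 above (the version tuples shown are examples only):

    from fairscale.utils.testing import torch_cuda_version

    compiled = torch_cuda_version(compiled=True)  # from torch.version.cuda, e.g. (10, 2)
    runtime = torch_cuda_version()  # from nvidia-smi, e.g. (11, 0)
    if runtime < (11, 0):
        print("skip: AMP equivalency is only checked on CUDA 11")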
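
The inputs change in ddp_ref() fixes an aliasing bug rather than a cosmetic issue; a standalone illustration:

    # [[]] * n repeats a reference to ONE list, so every rank shares the same inputs.
    inputs = [[]] * 2
    inputs[0].append("batch0")
    assert inputs[1] == ["batch0"]  # rank 1 silently sees rank 0's data

    # A list comprehension builds an independent list per rank.
    inputs = [[] for _ in range(2)]
    inputs[0].append("batch0")
    assert inputs[1] == []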