[fix] FSDP: disable single rank process group for auto_wrap_bn and fixed mixed precision regnet test #556

Merged
merged 17 commits into from
Mar 31, 2021
Changes from 1 commit
14 changes: 11 additions & 3 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1517,7 +1517,7 @@ def _pre_load_state_dict_hook(
########################################################################################


def auto_wrap_bn(module: nn.Module) -> nn.Module:
def auto_wrap_bn(module: nn.Module, single_rank_pg: bool = False) -> nn.Module:
"""
Auto wrap all BatchNorm (BN) instances with a safer FSDP, especially when
conversion to sync BN is used and the outer FSDP is flattening.
@@ -1531,6 +1531,9 @@ def auto_wrap_bn(module: nn.Module) -> nn.Module:
Args:
module (nn.Module):
The model (or part of the model) in which BNs are to be pre-wrapped.
single_rank_pg (bool):
If True, put BNs in a single-rank process group. Default: False.
This might be needed for Apex sync BN support; still under construction.

Returns:
Processed module, where BNs are wrapped with a special FSDP instance.
@@ -1543,10 +1546,15 @@ def wrap_bn_only_policy(module: nn.Module, recurse: bool, unwrapped_params: int)
else:
return is_bn and not isinstance(module, tuple(default_auto_wrap_policy.EXCLUDE_WRAP_MODULES)) # type: ignore

my_rank = dist.get_rank()
pg = None
if single_rank_pg:
Contributor Author:

This is the main change. For now I put it behind an option; I may need it later for Apex BN or for performance. The option still seems to be buggy in FSDP and needs more debugging. Enabling it also affects state_dict saving/loading as well as the optimizer state.

# No sharding with this single member group.
my_rank = dist.get_rank()
pg = dist.new_group(ranks=[my_rank])

fsdp_config = {
"wrapper_cls": FullyShardedDataParallel,
"process_group": dist.new_group(ranks=[my_rank]), # No sharding with this single member group.
"process_group": pg,
"mixed_precision": False, # Keep the weights in FP32.
"flatten_parameters": False, # Do not flatten.
}
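Not part of the PR, but for orientation: a rough sketch of how the new `single_rank_pg` flag might be exercised in isolation. The single-process `gloo` setup and the toy model are assumptions for illustration; the PR's actual usage is in `tests/nn/data_parallel/test_fsdp_regnet.py` below.

```python
# Hypothetical standalone sketch (not from the PR) of the new `single_rank_pg` option.
import os

import torch.distributed as dist
from torch.nn import BatchNorm2d, Conv2d, ReLU, Sequential

from fairscale.nn.data_parallel.fully_sharded_data_parallel import auto_wrap_bn


def main() -> None:
    # auto_wrap_bn(single_rank_pg=True) calls dist.get_rank()/new_group(),
    # so a process group must already be initialized.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    model = Sequential(Conv2d(3, 8, (3, 3)), BatchNorm2d(8), ReLU())

    # Each BN is wrapped in its own FSDP instance whose process group contains
    # only this rank, i.e. the BN parameters are not sharded across ranks.
    model = auto_wrap_bn(model, single_rank_pg=True)
    print(model)  # the BatchNorm2d now appears inside an FSDP wrapper

    # In real training the result would then be wrapped in an outer FSDP and
    # moved to GPU, as done in the regnet test below.


if __name__ == "__main__":
    main()
```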
2 changes: 1 addition & 1 deletion fairscale/utils/testing.py
@@ -445,7 +445,7 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False, dict_key: O
# Add dict key to the assertion error.
msg = e.args[0]
new_msg = f"For dict key '{dict_key}': {msg}"
raise AssertionError(new_msg)
raise AssertionError(new_msg) from None
Contributor Author:

This makes the assertion output better: it doesn't print the first exception followed by "while handling the first exception, the second exception is raised" before the new one.

else:
raise e
else:
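An aside not from the PR: the effect of `raise ... from None` can be demonstrated with a tiny standalone snippet. Without `from None`, Python prints the original AssertionError plus the implicit-chaining banner ("During handling of the above exception, another exception occurred") before the re-raised one; with it, only the new exception is shown.

```python
# Minimal standalone demonstration of suppressing exception chaining with `from None`.
def compare(value: int) -> None:
    try:
        assert value == 42, "values are not equal"
    except AssertionError as e:
        # Re-raise with added context. `from None` drops the implicit chain,
        # so the traceback shows only this AssertionError.
        raise AssertionError(f"For dict key 'foo': {e.args[0]}") from None


compare(41)
# Traceback (most recent call last):
#   ...
# AssertionError: For dict key 'foo': values are not equal
```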
101 changes: 86 additions & 15 deletions tests/nn/data_parallel/test_fsdp_regnet.py
@@ -16,7 +16,7 @@
import pytest
import torch
import torch.multiprocessing as mp
from torch.nn import BatchNorm2d, Conv2d, Module, SyncBatchNorm
from torch.nn import AdaptiveAvgPool2d, BatchNorm2d, Conv2d, Linear, Module, ReLU, Sequential, Sigmoid, SyncBatchNorm
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD

@@ -32,18 +32,85 @@
torch_version,
)

# Const test params.
# Reduce iterations to 1 for debugging.
# Change world_size to 8 on beefy machines for better test coverage.
_world_size = 2
_iterations = 100

# TODO (Min): test inplace relu in the future.
_relu_inplace = False

# TODO (Min): test apex BN when available in the future.
try:
import apex

apex_bn_converter = apex.parallel.convert_syncbn_model
except ImportError:
apex_bn_converter = None
pytorch_bn_converter = SyncBatchNorm.convert_sync_batchnorm # type: ignore
_bn_converter = pytorch_bn_converter
_single_rank_pg = False


class ResBlock(Module):
"""Conv block in regnet with residual connection."""

def __init__(self, width_in, width_out):
super().__init__()
self.proj = Conv2d(width_in, width_out, (1, 1), (2, 2), bias=False)
self.bn = BatchNorm2d(width_out)
self.f = Sequential(
Sequential( # block a
Conv2d(width_in, width_out, (1, 1), (1, 1), bias=False), BatchNorm2d(width_out), ReLU(_relu_inplace),
),
Sequential( # block b
Conv2d(width_out, width_out, (3, 3), (2, 2), (1, 1), groups=2, bias=False),
BatchNorm2d(width_out),
ReLU(_relu_inplace),
),
Sequential( # block se
AdaptiveAvgPool2d((1, 1)),
Sequential(
Conv2d(width_out, 2, (1, 1), (1, 1), bias=False),
ReLU(_relu_inplace),
Conv2d(2, width_out, (1, 1), (1, 1), bias=False),
Sigmoid(),
),
),
Conv2d(width_out, width_out, (1, 1), (1, 1), bias=False), # block c
BatchNorm2d(width_out), # final_bn
)
self.relu = ReLU()
self.need_fsdp_wrap = True

def forward(self, x):
x = self.bn(self.proj(x)) + self.f(x)
return self.relu(x)


class Model(Module):
"""SSL model with trunk and head."""

def __init__(self):
super().__init__()
# TODO (Min): for now, we just test pytorch sync_bn here.
# this will grow into regnet; testing apex sync_bn, etc.
self.conv = Conv2d(2, 2, (1, 1))
self.bn = BatchNorm2d(2)
# trunk
self.trunk = Sequential()
self.trunk.need_fsdp_wrap = True # Set a flag for later wrapping.
stem = Sequential(Conv2d(2, 4, (3, 3), (2, 2), (1, 1), bias=False), BatchNorm2d(4), ReLU(_relu_inplace),)
any_stage_block1_0 = ResBlock(4, 8)
self.trunk.add_module("stem", stem)
self.trunk.add_module("any_stage_block1", Sequential(any_stage_block1_0))

# head
Contributor:

nit: maybe remove the comment, since the variable is already named head?

Contributor Author:

Good point, will do.

self.head = Sequential(
Sequential(Linear(16, 16), ReLU(), Linear(16, 8),), # projection_head
Linear(8, 15, bias=False), # prototypes0
)

def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.trunk(x).reshape(-1)
x = self.head(x)
return x


@@ -67,8 +134,9 @@ def ddp_ref():
state_before = model.state_dict()

# Get reference inputs per rank.
world_size = 2
iterations = 100
world_size = _world_size
iterations = _iterations
print(f"Getting DDP reference for world_size {world_size} and iterations {iterations}")
inputs = [[]] * world_size
for rank in range(world_size):
for i in range(iterations):
@@ -150,19 +218,22 @@ def _test_func(

if ddp:
model = SyncBatchNorm.convert_sync_batchnorm(model)
model = DDP(model, device_ids=[rank])
model = DDP(model, device_ids=[rank], broadcast_buffers=True)
else:
# Note: different ranks may wrap in a different order due to different random
# seeds, but the results should be the same.
if random.randint(0, 1) == 0:
print("auto_wrap_bn, then convert_sync_batchnorm")
model = auto_wrap_bn(model)
model = SyncBatchNorm.convert_sync_batchnorm(model)
model = auto_wrap_bn(model, _single_rank_pg)
model = _bn_converter(model)
else:
print("convert_sync_batchnorm, then auto_wrap_bn")
model = SyncBatchNorm.convert_sync_batchnorm(model)
model = auto_wrap_bn(model)
model = _bn_converter(model)
model = auto_wrap_bn(model, _single_rank_pg)
model = FSDP(model, **fsdp_config).cuda()
# Print the model for verification.
if rank == 0:
print(model)
Contributor:

Is this useful in tests?

Contributor Author:

Yes, it is useful for verifying how FSDP wrapped each module. The output is a bit long, but it is normally not shown unless there is a failure.

optim = SGD(model.parameters(), lr=0.1)

for in_data in inputs[rank]:
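A small aside, not part of the PR: the wrapping structure that `print(model)` reveals could also be checked programmatically. This sketch assumes the same `FullyShardedDataParallel` class used by the test:

```python
# Illustrative helper (not in the PR): list which submodules ended up wrapped
# in FSDP, as a programmatic alternative to eyeballing print(model).
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP


def report_fsdp_wrapping(model) -> None:
    wrapped = [name or "<root>" for name, m in model.named_modules() if isinstance(m, FSDP)]
    print(f"FSDP-wrapped modules ({len(wrapped)}): {wrapped}")
```

Calling `report_fsdp_wrapping(model)` right after the wrapping above would, for example, show each BN wrapped by `auto_wrap_bn` as its own FSDP child in addition to the outer wrapper.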
@@ -215,7 +286,7 @@ def test1(temp_files, ddp_ref, precision, flatten):
fsdp_config["mixed_precision"] = precision == "mixed"
fsdp_config["flatten_parameters"] = flatten == "flatten"

world_size = 2
world_size = _world_size
mp.spawn(
_test_func,
args=(world_size, fsdp_config, None, temp_files[0], temp_files[1], state_before, inputs, None, state_after),