[fix] FSDP: disable single rank process group for auto_wrap_bn and fixed mixed precision regnet test #556
@@ -445,7 +445,7 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False, dict_key: O
                 # Add dict key to the assertion error.
                 msg = e.args[0]
                 new_msg = f"For dict key '{dict_key}': {msg}"
-                raise AssertionError(new_msg)
+                raise AssertionError(new_msg) from None
             else:
                 raise e
         else:

Review comment: This makes the assertion output better: it doesn't print the first exception and then report that a second exception was raised while handling the first.
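As a side note (not part of the PR), here is a minimal, self-contained sketch of what "from None" changes; the helper name and messages below are made up for illustration:

# Illustrative only: suppressing exception chaining with "from None".
def compare(a, b, dict_key):
    try:
        assert a == b, f"{a} != {b}"
    except AssertionError as e:
        new_msg = f"For dict key '{dict_key}': {e.args[0]}"
        # Without "from None", Python prints the original AssertionError first,
        # then "During handling of the above exception, another exception occurred".
        # With "from None", only this re-raised error appears in the traceback.
        raise AssertionError(new_msg) from None


compare({"w": 1}, {"w": 2}, dict_key="w")  # raises: For dict key 'w': ...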
@@ -16,7 +16,7 @@
 import pytest
 import torch
 import torch.multiprocessing as mp
-from torch.nn import BatchNorm2d, Conv2d, Module, SyncBatchNorm
+from torch.nn import AdaptiveAvgPool2d, BatchNorm2d, Conv2d, Linear, Module, ReLU, Sequential, Sigmoid, SyncBatchNorm
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import SGD

@@ -32,18 +32,85 @@
    torch_version,
)

# Const test params.
# Reduce iterations to 1 for debugging.
# Change world_size to 8 on beefy machines for better test coverage.
_world_size = 2
_iterations = 100

# TODO (Min): test inplace relu in the future.
_relu_inplace = False

# TODO (Min): test apex BN when available in the future.
try:
    import apex

    apex_bn_converter = apex.parallel.convert_syncbn_model
except ImportError:
    apex_bn_converter = None
pytorch_bn_converter = SyncBatchNorm.convert_sync_batchnorm  # type: ignore
_bn_converter = pytorch_bn_converter
_single_rank_pg = False


class ResBlock(Module):
    """Conv block in regnet with residual connection."""

    def __init__(self, width_in, width_out):
        super().__init__()
        self.proj = Conv2d(width_in, width_out, (1, 1), (2, 2), bias=False)
        self.bn = BatchNorm2d(width_out)
        self.f = Sequential(
            Sequential(  # block a
                Conv2d(width_in, width_out, (1, 1), (1, 1), bias=False), BatchNorm2d(width_out), ReLU(_relu_inplace),
            ),
            Sequential(  # block b
                Conv2d(width_out, width_out, (3, 3), (2, 2), (1, 1), groups=2, bias=False),
                BatchNorm2d(width_out),
                ReLU(_relu_inplace),
            ),
            Sequential(  # block se
                AdaptiveAvgPool2d((1, 1)),
                Sequential(
                    Conv2d(width_out, 2, (1, 1), (1, 1), bias=False),
                    ReLU(_relu_inplace),
                    Conv2d(2, width_out, (1, 1), (1, 1), bias=False),
                    Sigmoid(),
                ),
            ),
            Conv2d(width_out, width_out, (1, 1), (1, 1), bias=False),  # block c
            BatchNorm2d(width_out),  # final_bn
        )
        self.relu = ReLU()
        self.need_fsdp_wrap = True

    def forward(self, x):
        x = self.bn(self.proj(x)) + self.f(x)
        return self.relu(x)


class Model(Module):
    """SSL model with trunk and head."""

    def __init__(self):
        super().__init__()
        # TODO (Min): for now, we just test pytorch sync_bn here.
        # this will grow into regnet; testing apex sync_bn, etc.
        self.conv = Conv2d(2, 2, (1, 1))
        self.bn = BatchNorm2d(2)
        # trunk
        self.trunk = Sequential()
        self.trunk.need_fsdp_wrap = True  # Set a flag for later wrapping.
        stem = Sequential(Conv2d(2, 4, (3, 3), (2, 2), (1, 1), bias=False), BatchNorm2d(4), ReLU(_relu_inplace),)
        any_stage_block1_0 = ResBlock(4, 8)
        self.trunk.add_module("stem", stem)
        self.trunk.add_module("any_stage_block1", Sequential(any_stage_block1_0))

        # head

Review comment: nit: maybe remove since the variable is named as head?
Reply: Good point. Will do.

        self.head = Sequential(
            Sequential(Linear(16, 16), ReLU(), Linear(16, 8),),  # projection_head
            Linear(8, 15, bias=False),  # prototypes0
        )

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.trunk(x).reshape(-1)
        x = self.head(x)
        return x

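For orientation, a quick shape walk-through of the model above. This is only a sketch: the input size is an assumption chosen so the flattened trunk output matches the head's Linear(16, 16), not necessarily the test's actual input.

import torch

# Assumes ResBlock and Model from the diff above are defined in scope.
model = Model()
x = torch.randn(2, 2, 2, 2)  # (batch=2, channels=2, H=2, W=2) -- assumed shape
out = model(x)
# trunk output is (2, 8, 1, 1) -> reshape(-1) gives 16 values -> head maps 16 -> 8 -> 15
print(out.shape)  # torch.Size([15])
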
@@ -67,8 +134,9 @@ def ddp_ref():
     state_before = model.state_dict()

-    world_size = 2
-    iterations = 100
+    # Get reference inputs per rank.
+    world_size = _world_size
+    iterations = _iterations
     print(f"Getting DDP reference for world_size {world_size} and iterations {iterations}")
     inputs = [[]] * world_size
     for rank in range(world_size):
         for i in range(iterations):

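The reference state and per-rank inputs captured here are what the FSDP runs are later checked against. A hedged illustration of such a check using objects_are_equal from the first diff; the import path is an assumption based on fairscale's test utilities, and the state dicts are toy values:

import torch

from fairscale.utils.testing import objects_are_equal  # import path assumed

ref_state = {"weight": torch.ones(2, 2)}
new_state = {"weight": torch.ones(2, 2)}
# With raise_exception=True, a mismatch raises an AssertionError whose message
# names the offending dict key (the behavior touched by the first diff).
objects_are_equal(ref_state, new_state, raise_exception=True)
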
@@ -150,19 +218,22 @@ def _test_func(

     if ddp:
         model = SyncBatchNorm.convert_sync_batchnorm(model)
-        model = DDP(model, device_ids=[rank])
+        model = DDP(model, device_ids=[rank], broadcast_buffers=True)
     else:
         # Note, different rank may wrap in different order due to different random
         # seeds. But results should be the same.
         if random.randint(0, 1) == 0:
             print("auto_wrap_bn, then convert_sync_batchnorm")
-            model = auto_wrap_bn(model)
-            model = SyncBatchNorm.convert_sync_batchnorm(model)
+            model = auto_wrap_bn(model, _single_rank_pg)
+            model = _bn_converter(model)
         else:
             print("convert_sync_batchnorm, then auto_wrap_bn")
-            model = SyncBatchNorm.convert_sync_batchnorm(model)
-            model = auto_wrap_bn(model)
+            model = _bn_converter(model)
+            model = auto_wrap_bn(model, _single_rank_pg)
         model = FSDP(model, **fsdp_config).cuda()
+        # Print the model for verification.
+        if rank == 0:
+            print(model)

     optim = SGD(model.parameters(), lr=0.1)

     for in_data in inputs[rank]:

Review comment (on the print(model) lines): Is this useful in tests?
Reply: Yes, it is useful for verifying how FSDP wrapped each module. The output is a bit long, but it is normally not shown unless there is a failure.
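Related to the review thread above: instead of (or in addition to) print(model), a test could assert on the wrapping programmatically. A small sketch, assuming model is the FSDP-wrapped model from the branch above and FSDP is the imported FullyShardedDataParallel class:

# Count how many submodules ended up wrapped in FSDP (including the root wrapper).
num_fsdp_wrapped = sum(1 for m in model.modules() if isinstance(m, FSDP))
print(f"FSDP-wrapped modules: {num_fsdp_wrapped}")
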
@@ -215,7 +286,7 @@ def test1(temp_files, ddp_ref, precision, flatten):
     fsdp_config["mixed_precision"] = precision == "mixed"
     fsdp_config["flatten_parameters"] = flatten == "flatten"

-    world_size = 2
+    world_size = _world_size
     mp.spawn(
         _test_func,
         args=(world_size, fsdp_config, None, temp_files[0], temp_files[1], state_before, inputs, None, state_after),

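For readers unfamiliar with the harness, this is the spawn-plus-file-rendezvous pattern that the test's temp_files fixture feeds into. The sketch below is illustrative only; the worker body and file handling are assumptions, not the test's actual code:

import os
import tempfile

import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank, world_size, init_file):
    # Each spawned process joins the same group via a shared file rendezvous.
    dist.init_process_group(
        backend="gloo", init_method=f"file://{init_file}", rank=rank, world_size=world_size
    )
    # ... build the model, wrap with DDP/FSDP, run the iterations ...
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    tmp_dir = tempfile.mkdtemp()
    init_file = os.path.join(tmp_dir, "rdzv")  # must not pre-exist for file:// init
    mp.spawn(_worker, args=(world_size, init_file), nprocs=world_size, join=True)
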
Review comment: This is the main change. For now, I put this behind an option and may need it later with apex BN or for performance. The option seems to be buggy in FSDP and needs more debugging. Enabling it also has an impact on state_dict saving/loading as well as on the optimizer state.
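For context on what the single rank process group option refers to: the idea is that each rank wraps its BN layers with a process group containing only itself, so BN parameters and buffers stay local instead of being sharded or synced across ranks. A rough sketch of building such groups with plain torch.distributed follows; this is an assumed mechanism for illustration, not fairscale's exact auto_wrap_bn implementation:

import torch.distributed as dist


def make_single_rank_pg(world_size):
    # Assumes the default process group has already been initialized.
    # new_group() is collective: every rank must call it for every group,
    # but each rank only keeps the group that contains itself.
    my_pg = None
    for rank in range(world_size):
        pg = dist.new_group(ranks=[rank])
        if rank == dist.get_rank():
            my_pg = pg
    return my_pg

# The returned group could then be passed as the process group when wrapping
# BatchNorm layers, keeping their params/buffers unsharded on each rank.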