[test] fix regnet test for ddp+mixed_precision
- need an AMP autocast context with FSDP mixed precision
- worked around a difference between DDP & FSDP when bias=True
- fixed a bug in input data generation that caused different ranks to have
  the same data and a wrong iteration count
- added a TODO about needing a better loss and grad_scaler, and reduced
  iterations so there are no NaNs
- added (disabled) debugging code
min-xu-ai committed Mar 30, 2021
1 parent dfb12d1 commit acc95a3
Showing 1 changed file with 18 additions and 3 deletions: tests/nn/data_parallel/test_fsdp_regnet.py
@@ -36,7 +36,7 @@
 # Reduce iterations to 1 for debugging.
 # Change world_size to 8 on beefy machines for better test coverage.
 _world_size = 2
-_iterations = 100
+_iterations = 5

 # TODO (Min): test inplace relu in the future.
 _relu_inplace = False
@@ -104,7 +104,9 @@ def __init__(self):

         # head
         self.head = Sequential(
-            Sequential(Linear(16, 16), ReLU(), Linear(16, 8),),  # projection_head
+            # TODO (Min): FSDP-mixed_precision doesn't compute the same way as DDP AMP when bias=True,
+            #             so we use bias=False for now in the projection_head.
+            Sequential(Linear(16, 16, bias=False), ReLU(), Linear(16, 8, bias=False),),  # projection_head
             Linear(8, 15, bias=False),  # prototypes0
         )

@@ -137,7 +139,7 @@ def ddp_ref():
     world_size = _world_size
     iterations = _iterations
     print(f"Getting DDP reference for world_size {world_size} and iterations {iterations}")
-    inputs = [[]] * world_size
+    inputs = [[] for i in range(world_size)]
     for rank in range(world_size):
         for i in range(iterations):
             inputs[rank].append(torch.rand(2, 2, 2, 2))
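The replaced line above fixes a Python aliasing pitfall: [[]] * world_size repeats references to one and the same mutable list, so every rank appended to and read from the same object. A standalone illustration of the difference (not part of the test file):

    buggy = [[]] * 2                 # two references to the *same* list object
    buggy[0].append("rank0 data")
    print(buggy)                     # [['rank0 data'], ['rank0 data']] -- both ranks see rank0's data

    fixed = [[] for _ in range(2)]   # one independent list per rank
    fixed[0].append("rank0 data")
    print(fixed)                     # [['rank0 data'], []]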
@@ -211,6 +213,9 @@ def _test_func(
     if fsdp_config:
         ddp = False
         assert isinstance(fsdp_config, dict), str(fsdp_config)
+        if fsdp_config["mixed_precision"]:
+            # To match DDP in AMP -O1, we need fp32 reduce scatter.
+            fsdp_config["fp32_reduce_scatter"] = True

     model = Model()
     model.load_state_dict(state_before)
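For context, a sketch of how a config dict with these keys is typically passed to fairscale's FullyShardedDataParallel, which accepts mixed_precision and fp32_reduce_scatter as constructor arguments; the single-process group setup and the tiny model are illustrative assumptions, not taken from the test:

    import torch
    import torch.distributed as dist
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    # FSDP needs an initialized process group; a one-process NCCL group is enough for this sketch.
    if not dist.is_initialized():
        dist.init_process_group("nccl", init_method="tcp://localhost:29500", rank=0, world_size=1)

    fsdp_config = {"mixed_precision": True, "fp32_reduce_scatter": True}
    model = FSDP(torch.nn.Linear(16, 8).cuda(), **fsdp_config)  # gradients are reduce-scattered in fp32
    optim = torch.optim.SGD(model.parameters(), lr=0.01)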
@@ -242,8 +247,11 @@ def _test_func(
         if ddp and ddp_mixed_precision:
             in_data = in_data.half()
             context = torch.cuda.amp.autocast(enabled=True)
+        if not ddp and fsdp_config["mixed_precision"] == True:
+            context = torch.cuda.amp.autocast(enabled=True)
         with context:
             out = model(in_data)
+            # TODO (Min): this loss is causing nan after ~10 iters, need a real loss and grad scaler.
             loss = out.sum()
         loss.backward()
         optim.step()
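The TODO above points at the eventual fix: a bounded loss plus a gradient scaler, so fp16 training does not overflow into NaN. A standalone sketch of that pattern with torch.cuda.amp.GradScaler (illustrative model, data, and loss; not code from this commit):

    import torch

    model = torch.nn.Linear(8, 8).cuda()
    optim = torch.optim.SGD(model.parameters(), lr=0.01)
    scaler = torch.cuda.amp.GradScaler()
    criterion = torch.nn.MSELoss()  # a real loss; the commit notes out.sum() goes NaN after ~10 iterations

    in_data = torch.rand(4, 8, device="cuda")
    target = torch.zeros(4, 8, device="cuda")

    with torch.cuda.amp.autocast(enabled=True):
        out = model(in_data)
        loss = criterion(out, target)

    scaler.scale(loss).backward()  # scale the loss so fp16 gradients don't underflow
    scaler.step(optim)             # unscales grads; skips the step if they contain inf/NaN
    scaler.update()
    optim.zero_grad()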
@@ -261,6 +269,13 @@ def _test_func(
         # Move tensors to CPU to compare numerics.
         for k, v in fsdp_state.items():
             fsdp_state[k] = v.cpu()
+        # Enable for debugging.
+        if False and rank == 0:
+            def dump(d):
+                for k, v in d.items():
+                    print(k, v)
+            dump(state_after)
+            dump(fsdp_state)
         assert objects_are_equal(state_after, fsdp_state, raise_exception=True)

     teardown()
