[fix] FSDP: disable single rank process group for auto_wrap_bn and fixed mixed precision regnet test #556
Conversation
- beefed up the unit test with a regnet-like model
- found that the single-rank process group is causing problems
- disabled it to enable convergence tests on the vissl side
- use `raise e from None` to get a better assertion output in testing.py
@@ -1543,10 +1546,15 @@ def wrap_bn_only_policy(module: nn.Module, recurse: bool, unwrapped_params: int)
    else:
        return is_bn and not isinstance(module, tuple(default_auto_wrap_policy.EXCLUDE_WRAP_MODULES))  # type: ignore

    my_rank = dist.get_rank()
    pg = None
    if single_rank_pg:
this is the main change. For now, I put this behind an option and may need it later for apex BN or for performance. This option seems to be buggy in FSDP and needs more debugging. Enabling it also has an impact on state_dict saving/loading as well as on the optimizer state.
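For reference, a minimal sketch (not the PR's code; the helper name is made up) of how a single-rank process group can be built with torch.distributed, which is what this option effectively does so that BN layers wrapped by FSDP are not sharded across ranks:

```python
import torch.distributed as dist

def make_single_rank_pg():
    # Hypothetical helper for illustration only.
    # dist.new_group() is a collective: every rank must create all groups in
    # the same order, even though each rank only keeps the group containing
    # itself.
    groups = [dist.new_group(ranks=[r]) for r in range(dist.get_world_size())]
    return groups[dist.get_rank()]
```

A one-member group leaves nothing to shard across, so the BN parameters stay whole on each rank.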
@@ -445,7 +445,7 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False, dict_key: O
            # Add dict key to the assertion error.
            msg = e.args[0]
            new_msg = f"For dict key '{dict_key}': {msg}"
            raise AssertionError(new_msg)
            raise AssertionError(new_msg) from None
this makes the assertion output better: it no longer prints the first exception followed by "During handling of the above exception, another exception occurred" before the second one.
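A small self-contained illustration of the effect (the function below is made up for the example):

```python
def check_equal(a, b, dict_key):
    try:
        assert a == b, f"{a} != {b}"
    except AssertionError as e:
        new_msg = f"For dict key '{dict_key}': {e}"
        # Without "from None", the traceback shows the original AssertionError
        # plus "During handling of the above exception, another exception
        # occurred" before the new one; "from None" suppresses that chaining.
        raise AssertionError(new_msg) from None

check_equal(1, 2, "weight")  # traceback shows only: For dict key 'weight': 1 != 2
```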
wow, this is passing locally on my machine but failing on CI. converting to draft for more debugging.
this is such awesome work @min-xu-ai. In case it is helpful, just adding a little more context on how we use syncBN: in SyncBatchNorm (Apex or Pytorch), the process groups are indeed very important. People frequently change process groups (for example, with 128 GPUs we can use process groups of size 64, which leads to 2 groups, or process groups of size 8, leading to 16 groups, etc.).
Thanks, @prigoyal. This is helpful info. I think we want to support apex sync BN as well, but it currently has more issues than pytorch BN. The process group here is actually NOT the pytorch sync BN group; it is the FSDP sharding group. I used to put each rank's BN in its own single-rank group, which avoids sharding them, but that seems to have bugs, so I am undoing it here.
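To make the two groups concrete, a rough sketch under the assumption of a pytorch SyncBatchNorm model wrapped by fairscale FSDP (illustrative wiring, not the PR's code):

```python
import torch.distributed as dist
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

def wrap(model: nn.Module) -> nn.Module:
    # Group that synchronizes BN statistics (the sync BN group): all ranks here.
    bn_stats_pg = dist.new_group(ranks=list(range(dist.get_world_size())))
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=bn_stats_pg)
    # Group that FSDP shards parameters over (the FSDP sharding group);
    # passing None uses the default world group, which is what this PR
    # reverts to for the BN wrapping.
    return FSDP(model, process_group=None)
```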
- need AMP context in FSDP
- work around a difference between DDP & FSDP when bias=True
- fixed a bug in input data generation that caused different ranks to have the same data and a wrong iteration count
- added a TODO for a better loss and grad_scaler, and reduced iterations so there is no NaN
- added a (disabled) piece of debugging code
I think the test is going to pass and this change should be ready to review after that. High level updates:
_smi_ver = None


def torch_cuda_version(compiled: bool = False) -> Tuple[int, ...]:
added a new function to get either the CUDA version torch was compiled with or the CUDA version reported by nvidia-smi.
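Roughly, such a helper could look like the sketch below (assumptions: torch.version.cuda for the compiled version and parsing nvidia-smi output for the driver-reported one; the real implementation in testing.py may differ, e.g. it caches the result in _smi_ver):

```python
import subprocess
from typing import Tuple

import torch

def torch_cuda_version(compiled: bool = False) -> Tuple[int, ...]:
    if compiled:
        # CUDA version PyTorch was built against, e.g. "11.1" -> (11, 1).
        return tuple(int(x) for x in torch.version.cuda.split("."))
    # CUDA version reported by the driver; nvidia-smi's header contains a
    # line like "... CUDA Version: 11.1 ...".
    out = subprocess.check_output(["nvidia-smi"]).decode()
    ver = out.split("CUDA Version:")[1].strip().split()[0]
    return tuple(int(x) for x in ver.split("."))
```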
@@ -33,4 +33,8 @@ tests/nn/pipe/test_deferred_batch_norm.py
tests/nn/pipe/test_dependency.py
tests/nn/pipe/test_stream.py
tests/experimental/nn/test_multiprocess_pipe.py
tests/nn/moe/test_moe_layer.py
move a few tests around to balance the CI runtime.
model = Model()
model.load_state_dict(state_before)
model = model.cuda()

class DummyScaler:
@anj-s I used this DummyScaler when scaling is not needed in full precision mode. It could be a useful thing to do in general to unify the full and mixed precision training loops.
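For readers without the diff open, a no-op scaler along these lines only needs to mirror the parts of the torch.cuda.amp.GradScaler interface the training loop calls (a sketch, not necessarily the exact class in the PR):

```python
import torch

class DummyScaler:
    """No-op stand-in for torch.cuda.amp.GradScaler in full precision runs."""

    def scale(self, loss: torch.Tensor) -> torch.Tensor:
        return loss  # no loss scaling needed in full precision

    def step(self, optimizer: torch.optim.Optimizer) -> None:
        optimizer.step()  # no inf/nan check, just step

    def update(self) -> None:
        pass  # nothing to adjust without a scale factor
```

With this, the same `scaler.scale(loss).backward(); scaler.step(optim); scaler.update()` loop works in both full and mixed precision modes.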
@@ -215,7 +334,10 @@ def test1(temp_files, ddp_ref, precision, flatten):
    fsdp_config["mixed_precision"] = precision == "mixed"
    fsdp_config["flatten_parameters"] = flatten == "flatten"

    world_size = 2
    if fsdp_config["mixed_precision"] and torch_cuda_version() < (11, 0):
@msbaines It is great that CI has a CUDA 11 VM. Otherwise, this test would have no place to run. It seems that we should gradually move everything to CUDA 11.
really awesome work @min-xu-ai and amazing insights!! :)
this is indeed interesting. In general, not having bias can introduce regularization issues and exploding training. I would encourage us to debug this.
Just double checking: we use the context for the autocasting only? I think this is awesome! pytorch AMP is also mandated by SDP, so we can make a consistent switch.
interesting: it seems like this can slow things down. Curious question: why do we require fp32 grad reduction? Is this something we require for all layers, or only because of BN (BN params should all be kept in fp32)?
interesting. This is essentially plain BN and not syncBN if we have to use single-rank BN. Definitely sounds like a bug somewhere. Overall, it seems like BN is introducing a lot of weirdness.
self.trunk.add_module("stem", stem) | ||
self.trunk.add_module("any_stage_block1", Sequential(any_stage_block1_0)) | ||
|
||
# head |
nit: maybe remove the comment, since the variable is already named head?
good point. will do.
        # head
        self.head = Sequential(
            # TODO (Min): FSDP-mixed_precision doesn't compute the same ways as DDP AMP when bias=True.
            #             so, we use bias=False for now in the projection_head.
does the second line need a tab?
yeah, it is not a tab, I was aligning it but I should do a 4-space indent. will fix.
        self.head = Sequential(
            # TODO (Min): FSDP-mixed_precision doesn't compute the same ways as DDP AMP when bias=True.
            #             so, we use bias=False for now in the projection_head.
            # The Conv2d layers above does not used bias in regnet, but even if they use
s/used/use
good catch!
        # head
        self.head = Sequential(
            # TODO (Min): FSDP-mixed_precision doesn't compute the same ways as DDP AMP when bias=True.
Thank you for the context with the comments! very helpful to grok :)
    scaler = ShardedGradScaler()
    # Print the model for verification.
    if rank == 0:
        print(model)
Is this useful in tests?
yes, it is useful for verifying the way FSDP wrapped each module. The output is a bit long, but it is normally not shown unless there is a failure.
    for in_data in inputs[rank]:
        in_data = in_data.cuda()
        context = contextlib.suppress()
        if ddp and ddp_mixed_precision:
            in_data = in_data.half()
            context = torch.cuda.amp.autocast(enabled=True)
        if not ddp and fsdp_config["mixed_precision"]:
            context = torch.cuda.amp.autocast(enabled=True)
just curious, when will this be true?
this is when we use FSDP in mixed precision mode. So we need to turn on pytorch AMP.
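In other words, the context selection boils down to something like this small helper (made-up name, just summarizing the quoted test code):

```python
import contextlib

import torch

def forward_context(ddp: bool, ddp_mixed_precision: bool, fsdp_mixed_precision: bool):
    # pytorch AMP autocast is needed for DDP-with-AMP runs and for FSDP runs
    # in mixed precision mode; otherwise the forward pass runs under a no-op
    # context manager.
    if (ddp and ddp_mixed_precision) or (not ddp and fsdp_mixed_precision):
        return torch.cuda.amp.autocast(enabled=True)
    return contextlib.suppress()
```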
@@ -190,6 +310,15 @@ def _test_func(
    # Move tensors to CPU to compare numerics.
    for k, v in fsdp_state.items():
        fsdp_state[k] = v.cpu()
    # Enable for debugging.
nit: Should we remove this for long running tests? I guess you are still debugging the issue, so it makes sense to have it. My only concern is that it adds to the execution time.
it is disabled with "if False". I will clarify it.
use `raise e from None` to get a better assertion output in testing.py.
Before submitting
What does this PR do?
Fixes # (issue).
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃