Fix fsdp+pp+te WPS decreasing issue #1139
Conversation
@@ -1650,6 +1654,9 @@ def _register_post_backward_hooks(self) -> None:
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0]  # Gets its GradAccumulation object.
handle = grad_acc.register_hook(functools.partial(self._post_backward_hook, p))
if not hasattr(p, "_shard_bwd_hooks"):
    p._shard_bwd_hooks = []
p._shard_bwd_hooks.append((grad_acc, handle))
p._shard_bwd_hook = (grad_acc, handle)
This should be deleted? See P841842878 CC @awgu
Commented out this line, following the existing style in this file.
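For context, here is a minimal, self-contained sketch of the registration pattern the diff above follows (the hook body and the exact FSDP call site are omitted; only the names from the diff are assumed):

    import functools
    import torch
    from torch.nn import Parameter

    def register_post_backward_hook(p: Parameter, post_backward_hook) -> None:
        # expand_as gives a non-leaf view whose grad_fn lets us reach the
        # parameter's AccumulateGrad node without copying data.
        p_tmp = p.expand_as(p)
        assert p_tmp.grad_fn is not None
        grad_acc = p_tmp.grad_fn.next_functions[0][0]  # AccumulateGrad node.
        handle = grad_acc.register_hook(functools.partial(post_backward_hook, p))
        # Track every (node, handle) pair in a list so hooks registered across
        # iterations can be found and removed later, instead of overwriting a
        # single _shard_bwd_hook tuple and losing the older handles.
        if not hasattr(p, "_shard_bwd_hooks"):
            p._shard_bwd_hooks = []
        p._shard_bwd_hooks.append((grad_acc, handle))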
for module_name, module in self.named_modules():
    if isinstance(module, FullyShardedDataParallel):
        module._module_fqn = module_name
These are only needed for debugging (when _FSDP_DEBUG is set; see P841842878).
Deleted.
if self.fp32_reduce_scatter:
    # Cast grad to FP32.
    param.grad.data = param.grad.data.float()
    orig_grad_data = param.grad.data.float()
else:
    orig_grad_data = param.grad.data
We should keep `param.grad.data = param.grad.data.float()`, so it would read something like:

if self.fp32_reduce_scatter:
    # Cast grad to FP32.
    param.grad.data = param.grad.data.float()
orig_grad_data = param.grad.data
Updated.
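To spell out why the reviewer's version matters, here is a small sketch of the cast (a sketch only; the function name and boolean flag are illustrative, not FSDP's API): assigning the FP32 copy back onto param.grad.data means any later code that reads param.grad (including the reduce-scatter path) sees FP32, whereas casting only into a local variable would leave param.grad in low precision.

    import torch

    def grad_for_reduce_scatter(param: torch.nn.Parameter, fp32_reduce_scatter: bool) -> torch.Tensor:
        """Return the gradient tensor to feed into reduce-scatter."""
        if fp32_reduce_scatter and param.grad is not None:
            # Cast grad to FP32 in place on the parameter, so subsequent
            # readers of param.grad also see the upcast values.
            param.grad.data = param.grad.data.float()
        # Either way, just read the (possibly upcast) gradient.
        return param.grad.data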
@@ -1710,6 +1713,13 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:

# Switch to FP32 shard after backward.
self._use_fp32_param_shard([param])
if self.mixed_precision and self.fp32_reduce_scatter:
Sorry, I missed this comment before. I wonder if we can put the main_grad-related changes in a separate commit from the WPS-decrease fix (P841842878).
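For readers unfamiliar with the main_grad pattern being discussed, here is a rough, hypothetical sketch (not this PR's code): the parameter keeps a persistent FP32 main_grad buffer, and each backward's low-precision gradient is accumulated into it before the gradient itself is freed.

    import torch

    def accumulate_into_main_grad(param: torch.nn.Parameter) -> None:
        # Hypothetical helper: maintain an FP32 accumulation buffer per
        # parameter and fold the current (possibly bf16) gradient into it.
        if getattr(param, "main_grad", None) is None:
            param.main_grad = torch.zeros_like(param.data, dtype=torch.float32)
        if param.grad is not None:
            param.main_grad.add_(param.grad.data.float())
            param.grad = None  # free the low-precision gradient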
What does this PR do?
Fixes WPS decreasing over training steps, caused by FSDP hook handle issues.
All credit goes to @awgu @vivien-chu. Merge this into ngoyal_changes_for_pp_fp8 due to the failure observed in test_te.py when reproducing https://github.com/fairinternal/xlformers/pull/1360 with @jiecaoyu @jspark1105.
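As background, here is a minimal sketch of how the tracked (grad_acc, handle) pairs from the diff above could be cleaned up so hooks do not accumulate across steps (the cleanup call site and function name are assumptions, not this PR's exact code):

    def remove_shard_bwd_hooks(params) -> None:
        # Drop every post-backward hook registered during this step so
        # handles do not pile up iteration after iteration and slow WPS.
        for p in params:
            if hasattr(p, "_shard_bwd_hooks"):
                for _grad_acc, handle in p._shard_bwd_hooks:
                    handle.remove()
                p._shard_bwd_hooks.clear()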
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.