
Fix fsdp+pp+te WPS decreasing issue #1139

Merged

Conversation


@jianyuh jianyuh commented Oct 1, 2023

What does this PR do?

Fixes WPS decreasing over training steps, caused by FSDP hook handle issues.

All credit goes to @awgu and @vivien-chu. Merging this into ngoyal_changes_for_pp_fp8 due to the failure observed in test_te.py when reproducing https://github.com/fairinternal/xlformers/pull/1360 with @jiecaoyu @jspark1105:

-- Process 5 terminated with the following error:
Traceback (most recent call last):
  File "/home/jianyuhuang/Work/Github/pytorch/torch/multiprocessing/spawn.py", line 74, in _wrap
    fn(i, *args)
  File "/home/jianyuhuang/Work/Github/xlformers/tests/test_te.py", line 470, in run_demo
    loss.backward()
  File "/home/jianyuhuang/Work/Github/pytorch/torch/_tensor.py", line 491, in backward
    torch.autograd.backward(
  File "/home/jianyuhuang/Work/Github/pytorch/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
SystemError: <built-in method run_backward of torch._C._EngineBase object at 0x7fba963bbc80> returned NULL without setting an exception
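(For context, a minimal, hypothetical repro skeleton, not the actual tests/test_te.py, whose contents are not shown here. It only mirrors the shape of the failing test in the traceback above: a model driven under torch.multiprocessing.spawn with loss.backward() in a step loop, logging words-per-second so a WPS regression across steps would be visible. The model, sizes, and token count are placeholders.)

    import time
    import torch
    import torch.multiprocessing as mp

    def run_demo(rank: int, world_size: int, steps: int = 50, tokens_per_step: int = 4096) -> None:
        # Placeholder model; the real test wraps an FSDP + TransformerEngine model.
        model = torch.nn.Linear(1024, 1024)
        opt = torch.optim.SGD(model.parameters(), lr=1e-3)
        for step in range(steps):
            t0 = time.time()
            x = torch.randn(8, 1024)
            loss = model(x).pow(2).mean()
            loss.backward()  # the call that fails in the traceback above
            opt.step()
            opt.zero_grad()
            wps = tokens_per_step / (time.time() - t0)
            print(f"rank {rank} step {step}: WPS={wps:.0f}")

    if __name__ == "__main__":
        mp.spawn(run_demo, args=(2,), nprocs=2, join=True)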

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 1, 2023
@jianyuh jianyuh marked this pull request as ready for review October 1, 2023 23:23
@@ -1650,6 +1654,9 @@ def _register_post_backward_hooks(self) -> None:
    assert p_tmp.grad_fn is not None
    grad_acc = p_tmp.grad_fn.next_functions[0][0]  # Gets its GradAccumulation object.
    handle = grad_acc.register_hook(functools.partial(self._post_backward_hook, p))
    if not hasattr(p, "_shard_bwd_hooks"):
        p._shard_bwd_hooks = []
    p._shard_bwd_hooks.append((grad_acc, handle))
    p._shard_bwd_hook = (grad_acc, handle)


This should be deleted? See P841842878 CC @awgu

jianyuh (Member, Author)


Commented out this line, following the style in this file.
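
(For readers outside the FairScale codebase, a hedged standalone sketch of the registration pattern being changed in the hunk above; the free-function names and the removal helper are hypothetical, and only the _shard_bwd_hooks bookkeeping mirrors the diff. The parameter's AccumulateGrad node is reached via the expand_as trick, the post-backward hook is registered on it, and every (grad_acc, handle) pair is appended to p._shard_bwd_hooks so all hooks can later be removed, instead of overwriting a single p._shard_bwd_hook and losing track of earlier registrations.)

    import functools
    import torch

    def register_post_backward_hook(p: torch.nn.Parameter, post_backward_hook) -> None:
        if not p.requires_grad:
            return
        p_tmp = p.expand_as(p)  # non-leaf view, so it carries a grad_fn
        assert p_tmp.grad_fn is not None
        grad_acc = p_tmp.grad_fn.next_functions[0][0]  # the parameter's AccumulateGrad node
        handle = grad_acc.register_hook(functools.partial(post_backward_hook, p))
        if not hasattr(p, "_shard_bwd_hooks"):
            p._shard_bwd_hooks = []
        # Track every registration so all of them can be cleaned up later.
        p._shard_bwd_hooks.append((grad_acc, handle))

    def remove_post_backward_hooks(p: torch.nn.Parameter) -> None:
        for _, handle in getattr(p, "_shard_bwd_hooks", []):
            handle.remove()
        p._shard_bwd_hooks = []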

Comment on lines 1224 to 1226
for module_name, module in self.named_modules():
    if isinstance(module, FullyShardedDataParallel):
        module._module_fqn = module_name


These are only needed for debugging (when _FSDP_DEBUG is set in P841842878).

jianyuh (Member, Author)


Deleted.
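
(For reference, a minimal sketch of the debugging aid discussed in this thread; it was deleted from the final diff, so none of this is in the merged PR. The env-var gate and function name are assumptions based on the _FSDP_DEBUG remark above.)

    import os
    import torch
    from fairscale.nn.data_parallel import FullyShardedDataParallel

    def tag_module_fqns(root: torch.nn.Module) -> None:
        # Only annotate modules when the debug flag is set.
        if not os.environ.get("_FSDP_DEBUG"):
            return
        for module_name, module in root.named_modules():
            if isinstance(module, FullyShardedDataParallel):
                # Stash the fully qualified name so debug logs can identify the wrapper.
                module._module_fqn = module_name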

Comment on lines 1736 to 1740
if self.fp32_reduce_scatter:
    # Cast grad to FP32.
    param.grad.data = param.grad.data.float()
    orig_grad_data = param.grad.data.float()
else:
    orig_grad_data = param.grad.data


We should keep param.grad.data = param.grad.data.float(), so something like:

            if self.fp32_reduce_scatter:
                # Cast grad to FP32.
                param.grad.data = param.grad.data.float()

            orig_grad_data = param.grad.data

jianyuh (Member, Author)


Updated.
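
(A hedged sketch of the suggested shape, written as a free function rather than the actual FSDP method; the function name and the explicit flag argument are illustrative. The point is to cast the gradient to FP32 once, in place, and then take orig_grad_data from param.grad.data, so only one FP32 copy is made and the tensor that is reduce-scattered stays referenced.)

    import torch

    def cast_grad_for_reduce_scatter(param: torch.nn.Parameter, fp32_reduce_scatter: bool) -> torch.Tensor:
        if fp32_reduce_scatter:
            # Cast grad to FP32 in place; the old code made a second .float() copy here.
            param.grad.data = param.grad.data.float()
        # Keep a reference to the tensor that will actually be reduce-scattered.
        orig_grad_data = param.grad.data
        return orig_grad_data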

@@ -1710,6 +1713,13 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:

    # Switch to FP32 shard after backward.
    self._use_fp32_param_shard([param])
    if self.mixed_precision and self.fp32_reduce_scatter:


Sorry, I missed this comment before. I wonder if we can have a separate commit for the main_grad-related changes, apart from the changes for the WPS decrease fix (P841842878).

jianyuh (Member, Author)


Split into 2 PRs (#1139 and #1140)
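
(For context on what was split out, a hedged sketch of the kind of main_grad handling the comment refers to; it landed in PR #1140, not here. The function name and exact placement are assumptions, and only the main_grad attribute name comes from the thread: under mixed precision with fp32_reduce_scatter, the low-precision gradient is accumulated into a persistent FP32 main_grad buffer after backward, and that buffer is what feeds the reduce-scatter.)

    import torch

    def accumulate_into_main_grad(param: torch.nn.Parameter) -> None:
        if param.grad is None:
            return
        if getattr(param, "main_grad", None) is None:
            # Lazily allocate the FP32 accumulation buffer.
            param.main_grad = torch.zeros_like(param.grad, dtype=torch.float32)
        param.main_grad.add_(param.grad.float())
        # Drop the low-precision grad; the reduce-scatter consumes main_grad instead.
        param.grad = None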

@jianyuh jianyuh mentioned this pull request Oct 2, 2023
@jianyuh jianyuh merged commit 0db6e62 into ngoyal_changes_for_pp_fp8 Oct 2, 2023
1 of 4 checks passed