
[fix]: let FSDP handle model with multiple forward pass and checkpoint #621

Merged: 20 commits merged into master from min/bug_617 on Apr 26, 2021

Conversation

@min-xu-ai (Contributor) commented Apr 21, 2021

This PR adds support for calling a module multiple times within a single forward pass while also applying activation checkpointing to that module.

Right now, the tested cases are:

  1. FSDP(ckpt(module))
  2. FSDP(ckpt(module, FSDP(BN), module))

A follow-up commit will test and enable cases of the form FSDP(ckpt(), ..., ckpt()).

The fix adds an option to checkpoint_wrapper that keeps a forward-pass counter on the checkpointed modules. FSDP checks this counter to determine whether the backward callback fired from the module is the last one in the larger, outer backward pass.

The PR also adds a new API to override the pre/post gradient divide factors; it is currently used in the test to make sure the results numerically match DDP.
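For illustration, the wrapping pattern exercised by the first tested case looks roughly like the sketch below. The maintain_forward_counter flag follows the test added in this PR, but the import paths and the rest of the setup are assumptions and may differ across fairscale versions:

import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.misc import checkpoint_wrapper

# Sketch only: assumes torch.distributed has already been initialized,
# which FSDP requires; import paths may differ across fairscale versions.
block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))

# Case 1: FSDP(ckpt(module)). maintain_forward_counter=True keeps a
# per-module forward counter so FSDP can tell which backward callback
# is the last one when the module runs more than once per forward pass.
model = FSDP(checkpoint_wrapper(block, maintain_forward_counter=True))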

Before submitting

  • Was this discussed/approved via a Github issue? (not needed for typos or doc improvements)
  • Did you read the contributor guideline?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Fixes part of #617.

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@facebook-github-bot added the CLA Signed label on Apr 21, 2021 (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed).
model.block1 = auto_wrap_bn(model.block1, single_rank_pg=False)
model.block2 = auto_wrap_bn(model.block2, single_rank_pg=False)
if with_checkpoint:
    model.block2 = checkpoint_wrapper(model.block2, maintain_forward_counter=True)
@min-xu-ai (Contributor, Author) commented:

@QuentinDuval, once this PR is merged, you will need to use maintain_forward_counter=True in vissl when multiple forward passes and checkpointing are both used.

@min-xu-ai marked this pull request as ready for review April 23, 2021 22:59
@sshleifer (Contributor) left a comment:

LGTM. Note that I don't have context on why pre_backward_hook_has_run used to be a list.

if param.grad.requires_grad:
-    raise RuntimeError("FullyShardedDataParallel only works with gradients that don't require grad")
+    raise RuntimeError("FullyShardedDataParallel only works with gradients that don't require gradients")
A contributor commented:

Is this a thing? Which gradients need gradients!?

@min-xu-ai (Contributor, Author) replied:

Great question! I think it came from @myleott originally. I would love to know more too. :-)
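(For what it's worth, one standard way a gradient tensor itself ends up requiring grad is double backprop: calling backward with create_graph=True keeps the computed gradients in the autograd graph. A small illustration, independent of FSDP:)

import torch

# Independent of FSDP: with create_graph=True the computed gradients stay
# in the autograd graph, so .grad itself requires grad (double backprop).
w = torch.randn(3, requires_grad=True)
loss = (w * w).sum()
loss.backward(create_graph=True)
print(w.grad.requires_grad)  # True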

@min-xu-ai (Contributor, Author) commented:

LGTM. Note that I don't have context on why pre_backward_hook_has_run used to be a list.

Thanks for the review @sshleifer! Making pre_backward_hook_has_run a list wasn't originally my change, but I think I know why it is that way.

Take this example:

def func():
    x = 0
    def nest():
        x = 1

The second assignment to x creates a new local variable inside the nest function, so the outer x is never updated. However, if you turn it into:

def func():
    x = [0]
    def nest():
        x[0] = 1

Then the second assignment starts with a read of the outer x (to look up the list) before writing to its element, which ensures the right variable is updated. It is a clever trick to turn the write into a read first so that Python knows which variable you are referring to. Using Python's nonlocal keyword achieves the same effect:

def func():
    x = 0
    def nest():
        nonlocal x
        x = 1
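A tiny self-contained check of the difference (illustrative only):

def plain():
    x = 0
    def nest():
        x = 1  # creates a new local x; the outer x is untouched
    nest()
    return x

def listy():
    x = [0]
    def nest():
        x[0] = 1  # reads the outer list, then mutates it in place
    nest()
    return x[0]

print(plain())  # 0: the outer x was never updated
print(listy())  # 1: the outer x[0] was updated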

if not torch.is_grad_enabled():
    return outputs  # don't register hooks if grad isn't enabled

pre_backward_hook_has_run = [False]
@min-xu-ai (Contributor, Author) commented:

Originally, this variable was local to the _register_pre_backward_hooks function, so it only deduplicated callbacks registered within a single call to that function. When the function is called multiple times, the local flag cannot prevent multiple callbacks. Moving the flag onto self solves this problem.
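A minimal toy sketch of the pattern being described; the class and method names here are illustrative, not the actual FSDP code:

# Illustrative toy, not the real FSDP code: shows why moving the
# "has run" flag from a local variable onto self deduplicates the
# final-backward callback across multiple registrations.
class ToyModule:
    def __init__(self):
        self.callbacks_queued = 0
        self._pre_backward_hook_has_run = False

    def register_with_local_flag(self):
        # Before: each registration call gets its own flag.
        pre_backward_hook_has_run = [False]

        def hook():
            if not pre_backward_hook_has_run[0]:
                pre_backward_hook_has_run[0] = True
                self.callbacks_queued += 1  # stands in for queueing the callback

        return hook

    def register_with_flag_on_self(self):
        # After: all registrations share one flag on self.
        def hook():
            if not self._pre_backward_hook_has_run:
                self._pre_backward_hook_has_run = True
                self.callbacks_queued += 1

        return hook


m = ToyModule()
for hook in (m.register_with_local_flag(), m.register_with_local_flag()):
    hook()
print(m.callbacks_queued)  # 2: the local flag does not deduplicate across calls

m = ToyModule()
for hook in (m.register_with_flag_on_self(), m.register_with_flag_on_self()):
    hook()
print(m.callbacks_queued)  # 1: the shared flag on self deduplicates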

@min-xu-ai merged commit a1612d7 into master Apr 26, 2021
@min-xu-ai deleted the min/bug_617 branch April 26, 2021 16:48
@min-xu-ai (Contributor, Author) commented:

I have merged this for now @myleott. Definitely happy to address more comments separately.
