[fix]: let FSDP handle model with multiple forward pass and checkpoint #621
Conversation
model.block1 = auto_wrap_bn(model.block1, single_rank_pg=False)
model.block2 = auto_wrap_bn(model.block2, single_rank_pg=False)
if with_checkpoint:
    model.block2 = checkpoint_wrapper(model.block2, maintain_forward_counter=True)
@QuentinDuval, once this PR is merged, you will need to use maintain_forward_counter=True in vissl whenever multiple forward passes and checkpointing are both used.
LGTM. Note that I don't have context on why pre_backward_hook_has_run used to be a list.
if param.grad.requires_grad:
-    raise RuntimeError("FullyShardedDataParallel only works with gradients that don't require grad")
+    raise RuntimeError("FullyShardedDataParallel only works with gradients that don't require gradients")
Is this a thing? Which gradients need gradients!?
Great question! I think it came from @myleott originally. I would love to know more too. :-)
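For context: a tensor stored in .grad can itself require grad when backward is run with create_graph=True (double backward, gradient penalties, etc.). A minimal PyTorch illustration, unrelated to this PR's code:

import torch

x = torch.randn(3, requires_grad=True)
y = (x ** 3).sum()

# create_graph=True keeps the gradient computation in the autograd graph,
# so the resulting gradient tensor itself requires grad.
(grad_x,) = torch.autograd.grad(y, x, create_graph=True)
print(grad_x.requires_grad)  # True -- this is the case the RuntimeError guards against

# Such gradients exist to support second-order derivatives (grad of grad):
grad_x.sum().backward()
print(x.grad)  # equals 6 * x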
Thanks for the review @sshleifer! The list was the usual trick for giving a nested function write access to a flag from the enclosing scope. Say the flag were a plain boolean: an assignment to it inside the hook would just create a new local variable rather than update the outer one. With a one-element list, the assignment is instead a read of the outer list followed by an in-place item write, so the outer flag actually changes (see the sketch below).
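A standalone sketch of that closure behavior (hypothetical names, not the actual FSDP code):

def register_with_plain_bool():
    has_run = False

    def hook():
        has_run = True  # creates a new local variable; the outer flag is untouched

    hook()
    return has_run  # still False


def register_with_list():
    has_run = [False]

    def hook():
        has_run[0] = True  # reads the outer list, then mutates it in place

    hook()
    return has_run[0]  # True


print(register_with_plain_bool())  # False
print(register_with_list())        # True

nonlocal achieves the same thing in Python 3, but neither approach helps once the flag has to be shared across multiple calls to the registering function, which is why this PR moves it onto self.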
if not torch.is_grad_enabled():
    return outputs  # don't register hooks if grad isn't enabled

pre_backward_hook_has_run = [False]
Originally, the variable was local to the _register_pre_backward_hooks function, so it only deduplicated callbacks registered within that single call. When the function is called multiple times (once per forward pass), each call gets its own local flag, which cannot prevent duplicate callbacks across calls. Moving the flag to self solves this, since all calls then share a single flag.
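A simplified sketch of the change, with illustrative names and bodies (the real FSDP hook registration does considerably more):

import torch


class FSDPLikeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Shared across every forward pass; reset when the backward pass finishes.
        self._pre_backward_hook_has_run = False

    def _register_pre_backward_hooks(self, outputs):
        # Before this PR, a local `pre_backward_hook_has_run = [False]` lived here,
        # so each call (one per forward pass) tracked its own flag independently.

        def _pre_backward_hook(*unused):
            if self._pre_backward_hook_has_run:
                return  # an output of *any* earlier forward pass already triggered it
            self._pre_backward_hook_has_run = True
            # ... all-gather full parameters, prepare gradient reduction, etc.

        for t in outputs:
            if t.requires_grad:
                t.register_hook(_pre_backward_hook)
        return outputs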
I have merged this for now @myleott. Definitely happy to address more comments separately.
This PR adds support for running a module multiple times within a single forward pass while also applying activation checkpointing to that module.
Right now, the tested cases are:
The next commit will test and enable FSDP(ckpt(), ..., ckpt()) type of cases.
The fix adds an option to checkpoint_wrapper so that it keeps a forward counter on the checkpointed modules; FSDP checks this counter to determine whether the backward callback fired from the module is the last one in the bigger, outer backward pass. The PR also adds a new API to override the pre/post gradient divide factors; it is currently used in the test to ensure a numerical match with DDP.
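A minimal sketch of the usage pattern this enables, modeled on the test snippet above. The import paths vary across fairscale versions, a torch.distributed process group must already be initialized, and the surrounding model is illustrative:

import torch
import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel as FSDP
from fairscale.nn.checkpoint import checkpoint_wrapper  # path may differ by version


class TwoPassModel(nn.Module):
    def __init__(self, with_checkpoint: bool = True):
        super().__init__()
        block = nn.Linear(8, 8)
        if with_checkpoint:
            # maintain_forward_counter=True makes checkpoint_wrapper count forward
            # passes, so FSDP can tell which backward callback is the final one
            # in the outer backward pass.
            block = checkpoint_wrapper(block, maintain_forward_counter=True)
        self.block = block

    def forward(self, x):
        # The same checkpointed module runs twice in a single forward pass.
        return self.block(self.block(x))


# Assumes torch.distributed.init_process_group(...) has already been called.
model = FSDP(TwoPassModel())
out = model(torch.randn(2, 8).requires_grad_())
out.sum().backward()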
Before submitting
What does this PR do?
Fixes part of #617.
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃