Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat][fix] ShardedDDP deferred init #558

Merged
merged 4 commits into from
Mar 30, 2021
Merged

Conversation

blefaudeux
Copy link
Contributor

@blefaudeux blefaudeux commented Mar 30, 2021

Before submitting

  • Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
  • Did you read the contributor guideline?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Another fix on https://fb.workplace.com/groups/778897096280386/permalink/915100272660067/ (+incoming lightning fix, but this one in fairscale could be useful to others). Make sure that it's possible to change the model device after it has been wrapped by ShardedDDP (without this PR the buckets would be on the wrong device. Not completely a "bug" I presume because the .to() call on ShardedDDP had an assert and would have reported that this was not supported)

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
cc @ananthsub @SeanNaren

Did you have fun?

Make sure you had fun coding 🙃

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 30, 2021
@blefaudeux blefaudeux marked this pull request as draft March 30, 2021 03:46
@blefaudeux blefaudeux force-pushed the shardedddp_deferred_init branch from 30ea553 to 1d4ddbe Compare March 30, 2021 03:51
@blefaudeux blefaudeux marked this pull request as ready for review March 30, 2021 05:09
@SeanNaren
Copy link

ah nice! this is a nice feature :) just for my own understanding, if we defer the model transfer for very large models will this come with any noticeable speed benefit by sharding on CPU -> then move to GPU?

Copy link
Contributor Author

@blefaudeux blefaudeux left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah nice! this is a nice feature :) just for my own understanding, if we defer the model transfer for very large models will this come with any noticeable speed benefit by sharding on CPU -> then move to GPU?

good question, but I don't think it would change too much, the idea here is more that it gives a little more flexibility for this step (move your model to device anytime, as long as it's ready for the step that's good)

def forward(self, *inputs: Any, **kwargs: Any) -> Any:
"""
Module forward pass, handles any DDP-specific work in the background. Primes the
backward pass for gradient reduction to the proper ranks.
"""

# Optionally check whether the trainable parameters have changed
# Deferred initialization, or change detection
needs_setup = len(self._grad_hooks) == 0
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is the real PR change, deffer the initialization so that we use the device at the first forward time, assumed correct

@@ -478,6 +483,10 @@ def _setup_backward_hooks(self) -> None:
This makes the gradient reduction automatic whenever there's a backward pass
"""

# Detach possible pre-existing hooks
while len(self._grad_hooks) > 0:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

recommended by the doc, detach the old hooks before attaching the new ones, I've never seen anything wrong with the current version but better follow the rules..

@blefaudeux blefaudeux force-pushed the shardedddp_deferred_init branch from 2f4ce67 to 7ca6597 Compare March 30, 2021 16:27
Copy link
Contributor

@min-xu-ai min-xu-ai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer)

# Move the model to another device post-construction
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@min-xu-ai this is the sanity check to make sure that this feature stays in

@blefaudeux blefaudeux merged commit daa1bad into master Mar 30, 2021
@blefaudeux blefaudeux deleted the shardedddp_deferred_init branch March 30, 2021 20:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants