[feat][fix] ShardedDDP deferred init #558
Conversation
Force-pushed from 30ea553 to 1d4ddbe
Ah nice, this is a nice feature :) Just for my own understanding: if we defer the model transfer for very large models, will this come with any noticeable speed benefit from sharding on CPU and then moving to GPU?
Good question, but I don't think it would change much. The idea here is more that it gives a little more flexibility for this step: you can move your model to its device at any time, and as long as it's in place before the step, that's good.
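To make that flexibility concrete, here is a minimal sketch (not from this PR) of the workflow it enables: build and wrap the model on CPU, then move it to the target device any time before the first forward pass. It assumes fairscale's `OSS` and `ShardedDataParallel` as used elsewhere in this thread, and spins up a single-process gloo process group purely for illustration; import paths may differ across fairscale versions.

```python
import os
import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim.oss import OSS

# ShardedDDP / OSS expect an initialized process group; a single-process
# gloo group is enough for this illustration.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(128, 10)  # the model is still on CPU at this point
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer)

# With deferred init, the device is only committed at the first forward pass,
# so moving the wrapped model post-construction is fine.
device = "cuda" if torch.cuda.is_available() else "cpu"
ddp_model.to(device)

outputs = ddp_model(torch.randn(32, 128, device=device))
```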
```python
def forward(self, *inputs: Any, **kwargs: Any) -> Any:
    """
    Module forward pass, handles any DDP-specific work in the background. Primes the
    backward pass for gradient reduction to the proper ranks.
    """

    # Optionally check whether the trainable parameters have changed
    # Deferred initialization, or change detection
    needs_setup = len(self._grad_hooks) == 0
```
This is the real PR change: defer the initialization so that we pick up the device at the first forward pass, which is assumed to be correct at that point.
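For readers skimming the diff, a toy sketch of the deferred-initialization pattern (not the actual fairscale code): the device-dependent setup is keyed off an empty hook list and only runs at the first forward call, once the wrapped module sits on its final device.

```python
import torch
from typing import Any, List

class LazySetupWrapper(torch.nn.Module):
    """Toy wrapper illustrating deferred init: device-dependent setup
    (hooks, buckets, ...) only happens at the first forward call."""

    def __init__(self, module: torch.nn.Module) -> None:
        super().__init__()
        self.module = module
        self._grad_hooks: List[Any] = []  # stays empty until the first forward

    def forward(self, *inputs: Any, **kwargs: Any) -> Any:
        # Same trigger as in the PR: no hooks registered yet means setup is needed
        if len(self._grad_hooks) == 0:
            self._setup_backward_hooks()
        return self.module(*inputs, **kwargs)

    def _setup_backward_hooks(self) -> None:
        # By this point the parameters live on their final device
        for p in self.module.parameters():
            if p.requires_grad:
                handle = p.register_hook(lambda grad: grad)  # placeholder hook
                self._grad_hooks.append(handle)
```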
```
@@ -478,6 +483,10 @@ def _setup_backward_hooks(self) -> None:
        This makes the gradient reduction automatic whenever there's a backward pass
        """

        # Detach possible pre-existing hooks
        while len(self._grad_hooks) > 0:
```
Recommended by the PyTorch docs: detach the old hooks before attaching the new ones. I've never seen anything go wrong with the current version, but better to follow the rules.
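Detaching just means calling `remove()` on the handles returned by `register_hook`, mirroring the `while len(self._grad_hooks) > 0` loop in the diff. A small standalone illustration (hypothetical names, not fairscale code):

```python
import torch

p = torch.nn.Parameter(torch.randn(4))
grad_hooks = [p.register_hook(lambda g: g) for _ in range(2)]

# Detach pre-existing hooks before attaching new ones, as the PyTorch docs
# recommend; each register_hook call returns a RemovableHandle.
while len(grad_hooks) > 0:
    grad_hooks.pop().remove()
```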
Force-pushed from 2f4ce67 to 7ca6597
linting..
LGTM
```python
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer)

# Move the model to another device post-construction
```
@min-xu-ai this is the sanity check to make sure that this feature stays in.
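A rough sketch of what such a sanity check can look like (hypothetical, not the literal unit test): wrap first, move afterwards, and confirm that a full forward/backward/step still goes through on the new device. It reuses the `ddp_model` and `optimizer` names from the snippet above and assumes the underlying model takes 128-dimensional inputs.

```python
import torch

def check_post_construction_move(ddp_model, optimizer, device: str) -> None:
    # Move the already-wrapped model, then make sure a training step still works
    ddp_model.to(device)
    inputs = torch.randn(8, 128, device=device)
    ddp_model(inputs).sum().backward()
    optimizer.step()
    # All parameters (and hence the lazily-built buckets) should now live on `device`
    assert all(p.device.type == torch.device(device).type for p in ddp_model.parameters())
```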
Before submitting
What does this PR do?
Another fix for https://fb.workplace.com/groups/778897096280386/permalink/915100272660067/ (plus an incoming Lightning fix, but this one in fairscale could be useful to others). This makes sure that it's possible to change the model device after it has been wrapped by ShardedDDP (without this PR the buckets would end up on the wrong device). Not completely a "bug" I presume, because the .to() call on ShardedDDP had an assert and would have reported that this was not supported.
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
cc @ananthsub @SeanNaren
Did you have fun?
Make sure you had fun coding 🙃