[feat] save memory by using bucket buffer only in backward #633
Conversation
This fixes bug #627.

- Added documentation to clarify the buffer's cost and the speed/memory tradeoff.
- Added setup/teardown calls so that the buffer is only allocated during the backward pass, saving more memory during forward and stepping so that it can be used for things like activations.
- Added a unit test that asserts the memory usage is in range (a sketch of the idea follows below).

Comparing with DDP:
1. The buffer size scales with the number of FSDP instances, not with model size.
2. The buffer is only allocated during backward.
3. The buffer is used for small tensors only, to reduce overhead.
4. The overlap of compute and reduction is very different.
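The unit test mentioned above asserts that allocated memory stays in an expected range. A minimal sketch of what such a check could look like, assuming a CUDA device; the helper name and bounds are illustrative, not the actual test from this PR:

```python
import torch

def assert_memory_in_range(lower_mb: float, upper_mb: float) -> None:
    # Hypothetical helper, not the PR's actual test: checks that the bytes
    # currently allocated by tensors on the active CUDA device fall inside
    # an expected window.
    used_mb = torch.cuda.memory_allocated() / (1024 * 1024)
    assert lower_mb <= used_mb <= upper_mb, (
        f"allocated {used_mb:.1f} MB, expected [{lower_mb}, {upper_mb}] MB"
    )
```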
@@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## NEXT - TBD
### Added
- FSDP: better memory usage for reduce bucket ([#633](https://github.com/facebookresearch/fairscale/pull/633))
@blefaudeux, one tricky thing here is that when I was modifying the changelog file, the PR number wasn't available yet. I had to come back and edit this file again. :-)
FSDP buckets small gradients so that reduction can be more efficient
for small tensors. ``bucket_cap_mb`` controls the bucket size in
MegaBytes (MB). Buckets are sub-divided based on world_size, so the
max shard size is roughly ``bucket_cap_mb / world_size``. Values <= 0
disable bucketing.
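To make the shard math concrete, a quick back-of-the-envelope example (the numbers are illustrative, not defaults):

```python
bucket_cap_mb = 25   # illustrative bucket size
world_size = 8       # illustrative number of ranks
# Buckets are sub-divided across ranks, so each rank's shard is roughly:
max_shard_mb = bucket_cap_mb / world_size  # 25 / 8 ~= 3.1 MB per rank
```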
I tried to document this completely here. Please let me know if this is too much and/or inaccurate.
@@ -37,6 +38,24 @@ def flush(self) -> None:
        self.callbacks.clear()
        self.output_shard = torch.zeros_like(self.data[0])

    def setup(self) -> None:
These two functions enable some memory savings outside of the backward pass.
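A minimal sketch of the pattern these functions enable; the class and field names are simplified stand-ins, not the exact fairscale implementation:

```python
import torch

class ReduceBucket:
    """Simplified reduce bucket: the flat buffer exists only between
    setup() and teardown(), i.e. during the backward pass."""

    def __init__(self, size: int, dtype: torch.dtype, device: torch.device):
        self.size, self.dtype, self.device = size, dtype, device
        self.buffer = torch.zeros(0)  # empty placeholder outside backward

    def setup(self) -> None:
        # Allocate the buffer at the start of the backward pass.
        if self.buffer.numel() == 0:
            self.buffer = torch.zeros(self.size, dtype=self.dtype, device=self.device)

    def teardown(self) -> None:
        # Release the buffer after gradient reduction so the memory is
        # available for activations in the next forward pass.
        self.buffer = torch.zeros(0)
```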
Looks good, and nice test!
# TODO (Min): the `group` used here in the key is the object hash, not the content
# hash. That means if FSDP instances are initialized with different process groups,
# even when the group members are in fact the same, we end up creating different
# buckets here.
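A toy illustration of the pitfall this TODO describes: when the dict key holds the group object itself, two groups with identical members still hash differently, so they get separate buckets. `FakeGroup` is a stand-in for a real process group:

```python
class FakeGroup:
    # Stand-in for a process group; the default object hash is identity-based.
    def __init__(self, ranks):
        self.ranks = ranks

g1 = FakeGroup([0, 1, 2, 3])
g2 = FakeGroup([0, 1, 2, 3])  # same members, different object

buckets = {}
buckets[(g1, "float32")] = "bucket A"
buckets[(g2, "float32")] = "bucket B"
assert len(buckets) == 2  # duplicate buckets despite identical membership
```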
oh, I see, this is a good point!
Just read this. Nice catch and very educational docs!