
[feat] save memory by using bucket buffer only in backward #633

Merged: 8 commits merged into master on Apr 28, 2021

Conversation

min-xu-ai (Contributor) commented Apr 27, 2021

  • this fixes the FSDP memory utilization issue (#627)
  • added documentation to clarify the buffer's cost and the speed/memory
    tradeoff
  • added setup/teardown calls so that the buffer is only allocated
    during the backward pass, freeing memory during forward and the optimizer
    step for things like activations (see the sketch after the DDP comparison
    below)
  • added a unit test that asserts memory usage stays within the expected range

Comparing with DDP:

  1. buffer size scales with the number of FSDP instances, not with model size
  2. buffer is only allocated during backward
  3. buffer is used for small tensors only, to reduce overhead
  4. the overlapping of compute and reduction is very different
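
The intended lifecycle looks roughly like the sketch below. The names (`BucketBuffer`, `train_step`) are hypothetical and this is not the actual fairscale API; it only illustrates when setup/teardown are meant to run so the buffer never has to coexist with forward activations or optimizer state.

```python
# Hypothetical sketch of the setup/teardown lifecycle described above.
# Not the fairscale API: it only shows when the bucket buffer is meant to exist.
import torch


class BucketBuffer:
    def __init__(self, capacity_mb: int, dtype: torch.dtype = torch.float16) -> None:
        elem_size = torch.tensor([], dtype=dtype).element_size()
        self.numel = capacity_mb * 1024 * 1024 // elem_size
        self.dtype = dtype
        self.data = None            # not allocated until backward starts

    def setup(self) -> None:
        # Allocate lazily, right before backward needs the staging space.
        if self.data is None:
            self.data = torch.zeros(self.numel, dtype=self.dtype)

    def teardown(self) -> None:
        # Free the buffer so forward and the optimizer step can reuse the
        # memory, e.g. for activations.
        self.data = None


def train_step(model, loss_fn, batch, target, optimizer, buffer: BucketBuffer):
    output = model(batch)           # forward: buffer is not allocated
    loss = loss_fn(output, target)
    buffer.setup()                  # allocate just-in-time for backward
    loss.backward()                 # gradients would be bucketed/reduced here
    buffer.teardown()               # release before the optimizer step
    optimizer.step()
    optimizer.zero_grad()
```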

Before submitting

  • Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
  • Did you read the contributor guideline?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?
  • Did you update the changelog? (if needed)

What does this PR do?

Fixes #627.

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in a GitHub issue, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@facebook-github-bot added the CLA Signed label on Apr 27, 2021
@min-xu-ai marked this pull request as a draft on April 27, 2021 01:21
@min-xu-ai mentioned this pull request on Apr 27, 2021
@min-xu-ai marked this pull request as ready for review on April 27, 2021 01:49
@@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## NEXT - TBD
### Added
- FSDP: better memory usage for reduce bucket ([#633](https://github.com/facebookresearch/fairscale/pull/633))
min-xu-ai (Contributor, Author) commented:
@blefaudeux, one tricky thing here is that when I was modifying the changelog file, the PR number wasn't available yet. I had to come back and edit this file again. :-)

Diff excerpt (docstring describing the bucketing parameter):

    controls the bucket size in MegaBytes (MB). Buckets are sub-divided
    based on world_size, so the max shard size is roughly
    ``bucket_cap_mb / world_size``. Values <= 0 disable bucketing.
    [...] be more efficient for small tensors.
min-xu-ai (Contributor, Author) commented:
I tried to document this completely here. Please let me know if this is too much and/or inaccurate.
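
For a concrete sense of the documented behaviour, here is a back-of-the-envelope sketch; the 25 MB cap and world size of 8 are illustrative values, not necessarily the library defaults.

```python
# Back-of-the-envelope sketch of the documented behaviour: buckets are
# sub-divided by world_size, so the per-rank shard is roughly
# bucket_cap_mb / world_size.  The values below are illustrative only.
bucket_cap_mb = 25          # example cap, not necessarily the library default
world_size = 8

max_shard_mb = bucket_cap_mb / world_size   # ~3.1 MB per rank
bucketing_enabled = bucket_cap_mb > 0       # values <= 0 disable bucketing

print(f"max shard size ~= {max_shard_mb:.2f} MB, bucketing enabled: {bucketing_enabled}")
```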

@@ -37,6 +38,24 @@ def flush(self) -> None:
self.callbacks.clear()
self.output_shard = torch.zeros_like(self.data[0])

def setup(self) -> None:
min-xu-ai (Contributor, Author) commented:
These two functions enable some memory savings outside of the backward pass.
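
A simplified sketch of the idea follows (not the actual fairscale implementation; the real bucket class and its reduce-scatter plumbing differ): the staging tensor exists only between setup() and teardown(), and flush() reduce-scatters whatever has been accumulated so far.

```python
# Simplified sketch, not the actual fairscale implementation: a bucket whose
# staging storage exists only between setup() and teardown().  flush()
# reduce-scatters whatever has been accumulated and runs per-tensor callbacks.
# (dist.reduce_scatter_tensor requires a reasonably recent PyTorch.)
from typing import Callable, List, Optional

import torch
import torch.distributed as dist


class Bucket:
    def __init__(self, shard_size: int, dtype: torch.dtype, group) -> None:
        self.shard_size = shard_size
        self.dtype = dtype
        self.group = group
        self.offset = 0                            # elements filled so far
        self.callbacks: List[Callable[[], None]] = []
        self.data: Optional[torch.Tensor] = None   # allocated only in setup()

    def setup(self) -> None:
        # Allocate the [world_size, shard_size] staging buffer for backward.
        world_size = dist.get_world_size(self.group)
        self.data = torch.zeros(world_size, self.shard_size, dtype=self.dtype)

    def flush(self) -> None:
        # Reduce-scatter the filled portion and run the queued callbacks.
        if self.offset == 0:
            return
        output = torch.zeros_like(self.data[0, : self.offset])
        dist.reduce_scatter_tensor(
            output, self.data[:, : self.offset].contiguous(), group=self.group
        )
        for callback in self.callbacks:
            callback()
        self.offset = 0
        self.callbacks.clear()

    def teardown(self) -> None:
        # Free the buffer outside of backward so the memory can be reused
        # (e.g. for activations during the next forward pass).
        assert self.offset == 0 and not self.callbacks, "flush() before teardown()"
        self.data = None
```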

Outdated review comment on tests/nn/data_parallel/test_fsdp_memory.py (resolved)
myleott (Contributor) left a comment:
Looks good, and nice test!

Outdated review comment on fairscale/nn/data_parallel/fully_sharded_data_parallel.py (resolved)
Outdated review comment on tests/nn/data_parallel/test_fsdp_memory.py (resolved)
@min-xu-ai merged commit a559403 into master on Apr 28, 2021
@min-xu-ai deleted the min/mem branch on April 28, 2021 00:46
Comment on lines +169 to +172
# TODO (Min): the `group` used here in the key is the object hash, not the content
# hash. That means if FSDP instances are initialized with different process groups,
# even when the group members are in fact the same, we end up creating different
# buckets here.
A contributor replied:
oh, I see, this is a good point!
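
To illustrate the point being made, here is a toy sketch (not fairscale code): a cache keyed on the group object allocates a second bucket for a second process group even when its member ranks are identical, because process groups hash by object identity rather than by content.

```python
# Toy illustration (not fairscale code) of the object-hash vs. content-hash
# point above: process groups hash by object identity, so a cache keyed on
# (dtype, group) allocates a separate bucket per group object even when the
# member ranks are identical.
import torch

buckets = {}


def get_bucket(dtype, group, shard_numel=1024):
    key = (dtype, group)                      # object hash, not content hash
    if key not in buckets:
        buckets[key] = torch.zeros(shard_numel, dtype=dtype)
    return buckets[key]


class FakeGroup:
    """Stand-in for a torch.distributed ProcessGroup holding the same ranks."""

    def __init__(self, ranks):
        self.ranks = ranks


group_a = FakeGroup([0, 1])
group_b = FakeGroup([0, 1])                   # same members, different object

get_bucket(torch.float16, group_a)
get_bucket(torch.float16, group_b)
print(len(buckets))                           # 2: the buffer memory is duplicated
```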

sshleifer (Contributor) commented Apr 30, 2021
Just read this. Nice catch and very educational docs!
