From 5160a639608ce7e208039517ef581e62c301b480 Mon Sep 17 00:00:00 2001 From: Min Xu Date: Mon, 26 Apr 2021 17:18:22 -0700 Subject: [PATCH 1/7] [feat] save memory by using bucket buffer only in backward - this fixes bug #627 - added documentation to clarify the buffer's cost and speed/memory tradeoff - added setup/teardown calls so that the buffer is only allocated during the backward pass, saving more memory for forward and stepping so that they can be used for things like activations. - added a unit test that assert the memory is in range. Comparing with DDP: 1. buffer size scales with # of FSDP not model size 2. buffer is only allocated during backward 3. buffer is used for small tensors only to reduce overhead 4. overlapping of compute-reduction is very different --- CHANGELOG.md | 2 + .../fully_sharded_data_parallel.py | 30 +++- fairscale/utils/reduce_scatter_bucketer.py | 21 +++ tests/ci_test_list_1.txt | 1 + tests/nn/data_parallel/test_fsdp_memory.py | 162 ++++++++++++++++++ 5 files changed, 212 insertions(+), 4 deletions(-) create mode 100644 tests/nn/data_parallel/test_fsdp_memory.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 0c1d98cf2..de6e59bef 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## NEXT - TBD +### Added +- FSDP: better memory usage for reduce bucket ## [0.3.6] - 2021-04-26 ### Added diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 4fc3c27bb..6fd21fe78 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -155,10 +155,30 @@ class FullyShardedDataParallel(nn.Module): *``cpu_offload``*. bucket_cap_mb (int, Optional): FSDP will bucket parameters so that gradient reduction can - potentially overlap with backward computation. bucket_cap_mb - controls the bucket size in MegaBytes (MB). Buckets are sub-divided - based on world_size, so the max shard size is roughly - ``bucket_cap_mb / world_size``. Values <= 0 disable bucketing. + be more efficient for small tensors. + ``bucket_cap_mb`` controls the bucket size in MegaBytes (MB). Buckets + are sub-divided based on world_size, so the max shard size is roughly + ``bucket_cap_mb / world_size``. When nested FSDP is used, + each FSDP instance will have a separate set of buckets (1 per bucket + per tensor tuple). Large gradient tensors + are directly reduced without using the buffer. The buffer is there to reduce + communication overhead for small tensors. Overlapping with computation + happens due to use of a different CUDA stream than the computation CUDA + stream. The total memory overhead per buffer is around + ``bucket_cap_mb / world_size * (world_size + 1)``. + The buffers are allocated during the backward pass and freed at the end + of the backward pass to save more memory for other phases of the + training process. + Note, the memory vs. speed tradeoff of bucket size is very different + from that of the DDP engine. In DDP, the buffer size ``1MB + n*cap_mb``, + until n is big enough to cover the entire model size. The order + of which buffer is ready there is more rigid and DDP requires all + gradients to be computed in the backward. 
In FSDP, the buffer + does not change with model size (it scales per # of FSDP instances) + and gradient ready order matters little and we have a final flush + call that ensures everything is reduced and not all gradients need + to be upfront known. Overlapping with compute is done differently too. + Values <= 0 disable bucketing. Default: 25. compute_device (torch.device, Optional): device for computation. If not given and module params are on a CUDA @@ -1739,6 +1759,8 @@ def wrap_bn_only_policy(module: nn.Module, recurse: bool, unwrapped_params: int) # **must** be False because BN's FSDP wrapper's pre-backward callback isn't called # within the checkpoint's outer backward when multiple forward passes are used. "reshard_after_forward": False, + # No bucketing or small bucketing should be enough for BNs. + "bucket_cap_mb": 0, } with enable_wrap(wrap_bn_only_policy, **fsdp_config): diff --git a/fairscale/utils/reduce_scatter_bucketer.py b/fairscale/utils/reduce_scatter_bucketer.py index b8e2eba54..14eaa8b81 100644 --- a/fairscale/utils/reduce_scatter_bucketer.py +++ b/fairscale/utils/reduce_scatter_bucketer.py @@ -21,6 +21,7 @@ def __init__(self, data: Tensor, group: ProcessGroup): self.output_shard = torch.zeros_like(data[0]) def flush(self) -> None: + """Flush content of the bucket.""" if self.offset == 0: assert len(self.callbacks) == 0 return @@ -37,6 +38,24 @@ def flush(self) -> None: self.callbacks.clear() self.output_shard = torch.zeros_like(self.data[0]) + def setup(self) -> None: + """ Setup the buffers if they are not allocated. + + Using ``setup`` and ``teardown``, we can ensure that the bucket + buffers are only allocated during the backward pass, hence saving more + memory to other parts of the training process, such as the forward pass + for activation memory. + """ + for tensor in [self.data, self.output_shard]: + if tensor.storage().size() == 0: + tensor.storage().resize_(tensor.size().numel()) + + def teardown(self) -> None: + """Tear down the bucket by freeing the memory""" + assert self.offset == 0 and self.callbacks == [], "Incorrect call of teardown" + for tensor in [self.data, self.output_shard]: + tensor.storage().resize_(0) + class ReduceScatterBucketer: """ @@ -131,6 +150,7 @@ def flush(self) -> None: """Reduce-scatter any partial buckets.""" for bucket in self.buckets.values(): bucket.flush() + bucket.teardown() @functools.lru_cache() def _get_shard_size(self, element_size: int, num_shards: int) -> int: @@ -148,4 +168,5 @@ def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket: shard_size = self._get_shard_size(tensor.element_size(), world_size) data = tensor.new_zeros((world_size, shard_size)) self.buckets[key] = Bucket(data, group) + self.buckets[key].setup() return self.buckets[key] diff --git a/tests/ci_test_list_1.txt b/tests/ci_test_list_1.txt index c273e66eb..ed40be164 100644 --- a/tests/ci_test_list_1.txt +++ b/tests/ci_test_list_1.txt @@ -1,3 +1,4 @@ +tests/nn/data_parallel/test_fsdp_memory.py tests/nn/data_parallel/test_fsdp_multiple_forward_checkpoint.py tests/nn/data_parallel/test_fsdp_multiple_wrapping.py tests/nn/data_parallel/test_fsdp_freezing_weights.py diff --git a/tests/nn/data_parallel/test_fsdp_memory.py b/tests/nn/data_parallel/test_fsdp_memory.py new file mode 100644 index 000000000..baae33ef0 --- /dev/null +++ b/tests/nn/data_parallel/test_fsdp_memory.py @@ -0,0 +1,162 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
+# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +# pylint: disable=missing-module-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring + +""" Test FSDP with GPU memory usage. """ + +import gc + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel +import torch.optim as optim + +from fairscale.nn import checkpoint_wrapper +from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP +from fairscale.nn.data_parallel import auto_wrap_bn +from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx + + +def get_global_group(): + """ + Singleton pytorch distributed group + Inspired by https://github.com/pytorch/fairseq + """ + if dist.is_initialized(): + if not hasattr(get_global_group, "_global_group"): + get_global_group._global_group = dist.new_group() + return get_global_group._global_group + else: + return None + + +def to_fsdp(module): + return FSDP(module, process_group=get_global_group()) + + +def dump_all_tensors(rank): + """Use this for debugging""" + if rank != 0: + return + for obj in gc.get_objects(): + try: + # Only need to check parameter type objects if asked. + ttype = str(type(obj)) + if torch.is_tensor(obj) or (hasattr(obj, "data") and torch.is_tensor(obj.data)): + print(ttype, obj.shape, obj.dtype, obj.device, id(obj), obj.storage().size()) + except Exception as e: + pass + + +def get_cur_mem(rank, result, prefix): + """Collect memory allocated values in a result dict in MB""" + result[prefix] = round(torch.cuda.memory_allocated() / 1024 / 1024) + + +class Model(nn.Module): + def __init__(self): + super().__init__() + self.stem = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3), nn.BatchNorm2d(64), nn.ReLU(inplace=True)) + self.blocks = nn.Sequential( + nn.Conv2d(64, 128, kernel_size=5, padding=2), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.Conv2d(128, 128, kernel_size=5, padding=2), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.Conv2d(128, 128, kernel_size=5, padding=2), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.AdaptiveAvgPool2d(output_size=(1, 1)), + nn.Flatten(), + ) + self.head = nn.Linear(128, 10) + + def forward(self, x): + return self.head(self.blocks(self.stem(x))) + + +def create_model(with_fsdp, with_checkpoint): + model = Model() + if with_fsdp: + model.stem = auto_wrap_bn(model.stem, single_rank_pg=False) + model.blocks = auto_wrap_bn(model.blocks, single_rank_pg=False) + if with_checkpoint: + model.blocks = checkpoint_wrapper(model.blocks) + model.stem = to_fsdp(model.stem) + model.blocks = to_fsdp(model.blocks) + model.head = to_fsdp(model.head) + else: + if with_checkpoint: + model.blocks = checkpoint_wrapper(model.blocks) + return model + + +def _distributed_worker(gpu_id, world_size, with_fsdp, with_checkpoint, filename, filename_rpc, expected): + torch.cuda.set_device(gpu_id) + + rank = gpu_id + result = dist_init(rank, world_size, filename, filename_rpc) + assert result, "Dist init failed" + + torch.manual_seed(0) + torch.backends.cudnn.deterministic = True + batch = torch.randn(size=(2, 3, 224, 224)).cuda() + + model = create_model(with_fsdp, with_checkpoint) + model = model.cuda() + if with_fsdp: + model = to_fsdp(model) + else: + model = DistributedDataParallel(model, device_ids=[gpu_id], bucket_cap_mb=500) + + criterion = nn.MSELoss() + 
optimizer = optim.SGD(model.parameters(), lr=1e-4) + + results = {} + for iteration in range(5): + get_cur_mem(gpu_id, results, f"iter {iteration}: start") + + out = model(batch) + get_cur_mem(gpu_id, results, f"iter {iteration}: after fwd") + + out = sum(o.sum() for o in out[0]) + fake_loss = criterion(out, torch.tensor(0.0).cuda()) + get_cur_mem(gpu_id, results, f"iter {iteration}: after loss") + + fake_loss.backward() + get_cur_mem(gpu_id, results, f"iter {iteration}: after bwd") + + optimizer.step() + get_cur_mem(gpu_id, results, f"iter {iteration}: after step") + + # It is important to use the loop below, not optimizer.zero_grad() to reclaim memory. + for p in model.parameters(): + p.grad = None + get_cur_mem(gpu_id, results, f"iter {iteration}: done") + + assert results == expected + + teardown() + + +@skip_if_single_gpu +@pytest.mark.parametrize("fsdp", ["ddp", "fsdp"]) +@pytest.mark.parametrize("ckpt", ["no_ckpt", "ckpt"]) +def test_fsdp_memory(fsdp, ckpt): + expected = {("ddp", "no_ckpt"): {}, ("ddp", "ckpt"): {}, ("fsdp", "no_ckpt"): {}, ("fsdp", "ckpt"): {}}[(ddp, ckpt)] + fsdp = fsdp == "fsdp" + ckpt = ckpt == "ckpt" + world_size = 2 + with temp_files_ctx(num=2) as temp_files: + mp.spawn( + _distributed_worker, (world_size, fsdp, ckpt, temp_files[0], temp_files[1], expected), nprocs=world_size + ) From f56288980956ba5e44c93723b398878bf412c84b Mon Sep 17 00:00:00 2001 From: Min Xu Date: Mon, 26 Apr 2021 18:22:44 -0700 Subject: [PATCH 2/7] add PR number to changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index de6e59bef..f099b3382 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## NEXT - TBD ### Added -- FSDP: better memory usage for reduce bucket +- FSDP: better memory usage for reduce bucket ([#633](https://github.com/facebookresearch/fairscale/pull/633)) ## [0.3.6] - 2021-04-26 ### Added From 4dce2d62df4b8aeee35ff156cb83bfa6e6d1c970 Mon Sep 17 00:00:00 2001 From: Min Xu Date: Mon, 26 Apr 2021 18:39:57 -0700 Subject: [PATCH 3/7] filled in with memory number on 1.9 --- tests/nn/data_parallel/test_fsdp_memory.py | 89 +++++++++++++++++++++- 1 file changed, 85 insertions(+), 4 deletions(-) diff --git a/tests/nn/data_parallel/test_fsdp_memory.py b/tests/nn/data_parallel/test_fsdp_memory.py index baae33ef0..6ed6f3abf 100644 --- a/tests/nn/data_parallel/test_fsdp_memory.py +++ b/tests/nn/data_parallel/test_fsdp_memory.py @@ -122,7 +122,7 @@ def _distributed_worker(gpu_id, world_size, with_fsdp, with_checkpoint, filename optimizer = optim.SGD(model.parameters(), lr=1e-4) results = {} - for iteration in range(5): + for iteration in range(3): get_cur_mem(gpu_id, results, f"iter {iteration}: start") out = model(batch) @@ -143,16 +143,97 @@ def _distributed_worker(gpu_id, world_size, with_fsdp, with_checkpoint, filename p.grad = None get_cur_mem(gpu_id, results, f"iter {iteration}: done") - assert results == expected + assert results == expected, f"{results} but expected {expected}" teardown() @skip_if_single_gpu -@pytest.mark.parametrize("fsdp", ["ddp", "fsdp"]) @pytest.mark.parametrize("ckpt", ["no_ckpt", "ckpt"]) +@pytest.mark.parametrize("fsdp", ["ddp", "fsdp"]) def test_fsdp_memory(fsdp, ckpt): - expected = {("ddp", "no_ckpt"): {}, ("ddp", "ckpt"): {}, ("fsdp", "no_ckpt"): {}, ("fsdp", "ckpt"): {}}[(ddp, ckpt)] + expected = { + ("ddp", "no_ckpt"): { + "iter 0: start": 9, + "iter 0: after fwd": 346, + "iter 0: after loss": 
346, + "iter 0: after bwd": 14, + "iter 0: after step": 14, + "iter 0: done": 9, + "iter 1: start": 9, + "iter 1: after fwd": 346, + "iter 1: after loss": 346, + "iter 1: after bwd": 14, + "iter 1: after step": 14, + "iter 1: done": 9, + "iter 2: start": 9, + "iter 2: after fwd": 346, + "iter 2: after loss": 346, + "iter 2: after bwd": 14, + "iter 2: after step": 14, + "iter 2: done": 9, + }, + ("fsdp", "no_ckpt"): { + "iter 0: start": 3, + "iter 0: after fwd": 340, + "iter 0: after loss": 340, + "iter 0: after bwd": 66, + "iter 0: after step": 66, + "iter 0: done": 3, + "iter 1: start": 3, + "iter 1: after fwd": 340, + "iter 1: after loss": 340, + "iter 1: after bwd": 66, + "iter 1: after step": 66, + "iter 1: done": 3, + "iter 2: start": 3, + "iter 2: after fwd": 340, + "iter 2: after loss": 340, + "iter 2: after bwd": 66, + "iter 2: after step": 66, + "iter 2: done": 3, + }, + ("ddp", "ckpt"): { + "iter 0: start": 9, + "iter 0: after fwd": 57, + "iter 0: after loss": 57, + "iter 0: after bwd": 14, + "iter 0: after step": 14, + "iter 0: done": 9, + "iter 1: start": 9, + "iter 1: after fwd": 57, + "iter 1: after loss": 57, + "iter 1: after bwd": 14, + "iter 1: after step": 14, + "iter 1: done": 9, + "iter 2: start": 9, + "iter 2: after fwd": 57, + "iter 2: after loss": 57, + "iter 2: after bwd": 14, + "iter 2: after step": 14, + "iter 2: done": 9, + }, + ("fsdp", "ckpt"): { + "iter 0: start": 3, + "iter 0: after fwd": 51, + "iter 0: after loss": 51, + "iter 0: after bwd": 66, + "iter 0: after step": 66, + "iter 0: done": 3, + "iter 1: start": 3, + "iter 1: after fwd": 51, + "iter 1: after loss": 51, + "iter 1: after bwd": 66, + "iter 1: after step": 66, + "iter 1: done": 3, + "iter 2: start": 3, + "iter 2: after fwd": 51, + "iter 2: after loss": 51, + "iter 2: after bwd": 66, + "iter 2: after step": 66, + "iter 2: done": 3, + }, + }[(fsdp, ckpt)] fsdp = fsdp == "fsdp" ckpt = ckpt == "ckpt" world_size = 2 From a359af23794241b2f0a60a6874c101041628370c Mon Sep 17 00:00:00 2001 From: Min Xu Date: Tue, 27 Apr 2021 16:19:08 -0700 Subject: [PATCH 4/7] addressed comments --- .../fully_sharded_data_parallel.py | 27 ++++++++++++------- fairscale/utils/reduce_scatter_bucketer.py | 5 ++++ tests/nn/data_parallel/test_fsdp_memory.py | 5 ++-- 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 6fd21fe78..96617e9ea 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -1246,15 +1246,6 @@ def _wait_for_post_backward(self) -> None: else: self.assert_state(TrainingState.BACKWARD_PRE) - def _remove_shard_bwd_hook(fsdp_module: FullyShardedDataParallel) -> None: - """Helper used below on all fsdp modules.""" - for p in fsdp_module.params: - if p.requires_grad: - if hasattr(p, "_shard_bwd_hook"): - assert len(p._shard_bwd_hook) == 2, len(p._shard_bwd_hook) - p._shard_bwd_hook[1].remove() - delattr(p, "_shard_bwd_hook") - if self._require_backward_grad_sync: # Flush any unreduced buckets in the post_backward stream. with torch.cuda.stream(self._streams["post_backward"]): @@ -1264,7 +1255,23 @@ def _remove_shard_bwd_hook(fsdp_module: FullyShardedDataParallel) -> None: if self.move_grads_to_cpu: # Wait for the non-blocking GPU -> CPU grad transfers to finish. torch.cuda.current_stream().synchronize() - # A backward pass is done, update root and nested FSDP's flags. 
+ + # A backward pass is done, clean up below. + + # Free reducer buffers. + if self._reducer is not None: + self._reducer.teardown() + + def _remove_shard_bwd_hook(fsdp_module: FullyShardedDataParallel) -> None: + """Helper used below on all fsdp modules.""" + for p in fsdp_module.params: + if p.requires_grad: + if hasattr(p, "_shard_bwd_hook"): + assert len(p._shard_bwd_hook) == 2, len(p._shard_bwd_hook) + p._shard_bwd_hook[1].remove() + delattr(p, "_shard_bwd_hook") + + # Update root and nested FSDP's hooks and flags. for m in self.modules(): # includes self if isinstance(m, FullyShardedDataParallel): _remove_shard_bwd_hook(m) diff --git a/fairscale/utils/reduce_scatter_bucketer.py b/fairscale/utils/reduce_scatter_bucketer.py index 14eaa8b81..71e37c8ea 100644 --- a/fairscale/utils/reduce_scatter_bucketer.py +++ b/fairscale/utils/reduce_scatter_bucketer.py @@ -150,6 +150,11 @@ def flush(self) -> None: """Reduce-scatter any partial buckets.""" for bucket in self.buckets.values(): bucket.flush() + + @torch.no_grad() + def teardown(self) -> None: + """Free buffers from all buckets.""" + for bucket in self.buckets.values(): bucket.teardown() @functools.lru_cache() diff --git a/tests/nn/data_parallel/test_fsdp_memory.py b/tests/nn/data_parallel/test_fsdp_memory.py index 6ed6f3abf..b1fc1a8db 100644 --- a/tests/nn/data_parallel/test_fsdp_memory.py +++ b/tests/nn/data_parallel/test_fsdp_memory.py @@ -138,9 +138,8 @@ def _distributed_worker(gpu_id, world_size, with_fsdp, with_checkpoint, filename optimizer.step() get_cur_mem(gpu_id, results, f"iter {iteration}: after step") - # It is important to use the loop below, not optimizer.zero_grad() to reclaim memory. - for p in model.parameters(): - p.grad = None + # It is important to use `set_to_none` below, not optimizer.zero_grad() to reclaim memory. + model.zero_grad(set_to_none=True) get_cur_mem(gpu_id, results, f"iter {iteration}: done") assert results == expected, f"{results} but expected {expected}" From 12afd63a899075a21e3efa7ebf997ffe86abed5b Mon Sep 17 00:00:00 2001 From: Min Xu Date: Tue, 27 Apr 2021 16:31:00 -0700 Subject: [PATCH 5/7] update comments --- .../fully_sharded_data_parallel.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 96617e9ea..69051c1e6 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -155,16 +155,15 @@ class FullyShardedDataParallel(nn.Module): *``cpu_offload``*. bucket_cap_mb (int, Optional): FSDP will bucket parameters so that gradient reduction can - be more efficient for small tensors. + be more efficient for small parameters. ``bucket_cap_mb`` controls the bucket size in MegaBytes (MB). Buckets are sub-divided based on world_size, so the max shard size is roughly - ``bucket_cap_mb / world_size``. When nested FSDP is used, - each FSDP instance will have a separate set of buckets (1 per bucket - per tensor tuple). Large gradient tensors - are directly reduced without using the buffer. The buffer is there to reduce - communication overhead for small tensors. Overlapping with computation - happens due to use of a different CUDA stream than the computation CUDA - stream. The total memory overhead per buffer is around + ``bucket_cap_mb / world_size``. There is one bucketer (with potentially + multiple ``bucket_cap_mb`` sized buffers shared by all FSDP instances. 
+ Large gradient tensors are directly reduced without using the buffers. + The buffers are there to reduce communication overhead for small tensors. + Overlapping with computation happens due to use of a different CUDA stream + than the computation CUDA stream. The total memory overhead per buffer is around ``bucket_cap_mb / world_size * (world_size + 1)``. The buffers are allocated during the backward pass and freed at the end of the backward pass to save more memory for other phases of the @@ -173,11 +172,12 @@ class FullyShardedDataParallel(nn.Module): from that of the DDP engine. In DDP, the buffer size ``1MB + n*cap_mb``, until n is big enough to cover the entire model size. The order of which buffer is ready there is more rigid and DDP requires all - gradients to be computed in the backward. In FSDP, the buffer - does not change with model size (it scales per # of FSDP instances) - and gradient ready order matters little and we have a final flush - call that ensures everything is reduced and not all gradients need - to be upfront known. Overlapping with compute is done differently too. + gradients to be computed in the backward. In FSDP, the buffer size + does not change with model size (it changes based on number of + tuples) and gradient ready order matters + little since FSDP has a final flush call that ensures everything is reduced + and not all gradients need to be upfront known. Overlapping with compute is + done differently too. Values <= 0 disable bucketing. Default: 25. compute_device (torch.device, Optional): From 2f1a7f018b3b3947a3e2132b7ea638fabe1ad31c Mon Sep 17 00:00:00 2001 From: Min Xu Date: Tue, 27 Apr 2021 17:01:01 -0700 Subject: [PATCH 6/7] fix for 1.6 --- tests/nn/data_parallel/test_fsdp_memory.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/nn/data_parallel/test_fsdp_memory.py b/tests/nn/data_parallel/test_fsdp_memory.py index b1fc1a8db..bb296365a 100644 --- a/tests/nn/data_parallel/test_fsdp_memory.py +++ b/tests/nn/data_parallel/test_fsdp_memory.py @@ -22,7 +22,7 @@ from fairscale.nn import checkpoint_wrapper from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP from fairscale.nn.data_parallel import auto_wrap_bn -from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx +from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx, torch_version def get_global_group(): @@ -139,7 +139,11 @@ def _distributed_worker(gpu_id, world_size, with_fsdp, with_checkpoint, filename get_cur_mem(gpu_id, results, f"iter {iteration}: after step") # It is important to use `set_to_none` below, not optimizer.zero_grad() to reclaim memory. 
- model.zero_grad(set_to_none=True) + if torch_version() >= (1, 7, 0): + model.zero_grad(set_to_none=True) + else: + for p in model.parameters(): + p.grad = None get_cur_mem(gpu_id, results, f"iter {iteration}: done") assert results == expected, f"{results} but expected {expected}" From da4082c4de0c5c5e1803a59cf79cd575790b75f4 Mon Sep 17 00:00:00 2001 From: Min Xu Date: Tue, 27 Apr 2021 17:04:12 -0700 Subject: [PATCH 7/7] add a todo --- fairscale/utils/reduce_scatter_bucketer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/fairscale/utils/reduce_scatter_bucketer.py b/fairscale/utils/reduce_scatter_bucketer.py index 71e37c8ea..7eeec32b9 100644 --- a/fairscale/utils/reduce_scatter_bucketer.py +++ b/fairscale/utils/reduce_scatter_bucketer.py @@ -166,6 +166,10 @@ def _get_shard_size(self, element_size: int, num_shards: int) -> int: return int(bucket_size // num_shards) def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket: + # TODO (Min): the `group` used here in the key is the object hash, not the content + # hash. That means if FSDP instances are initialized with different process groups, + # even when the group members are in fact the same, we end up creating different + # buckets here. key = (tensor.dtype, tensor.device, group) if key not in self.buckets: # buckets are divided into world_size pieces, bucket.data shaped (world_size, shard_size)
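
Note on the storage-resize trick used in the patches above: Bucket.setup() and Bucket.teardown() (PATCH 1) allocate and free the bucket buffers by resizing the tensors' underlying storage rather than by dropping the tensor objects. The standalone sketch below uses only plain PyTorch; the world_size, bucket_cap_mb and float32 values are made-up illustrative numbers, not values taken from a running FSDP instance. It reproduces the same resize calls the patch uses and checks the per-buffer overhead estimate of roughly ``bucket_cap_mb / world_size * (world_size + 1)`` MB quoted in the updated docstring.

    import torch

    # Illustrative values only (not FSDP defaults read at runtime).
    world_size = 2
    bucket_cap_mb = 25
    MB = 1024 * 1024

    # Mirror ReduceScatterBucketer._get_shard_size(): the bucket_cap_mb budget,
    # expressed in elements, is split into world_size shards.
    element_size = torch.zeros(1).element_size()  # 4 bytes for default float32
    bucket_size = bucket_cap_mb * MB / element_size
    shard_size = int(bucket_size // world_size)

    # A Bucket holds a (world_size, shard_size) input buffer plus a shard-sized
    # output buffer, i.e. roughly bucket_cap_mb / world_size * (world_size + 1) MB.
    data = torch.zeros((world_size, shard_size))
    output_shard = torch.zeros_like(data[0])
    total_mb = (data.numel() + output_shard.numel()) * element_size / MB
    print(f"allocated ~{total_mb} MB")  # 37.5 MB for the values above

    # teardown(): release the memory but keep the tensor objects alive.
    for t in (data, output_shard):
        t.storage().resize_(0)
    print(data.storage().size(), output_shard.storage().size())  # 0 0

    # setup(): resize the storage back when the buffers are needed again,
    # e.g. at the start of the next backward pass.
    for t in (data, output_shard):
        if t.storage().size() == 0:
            t.storage().resize_(t.size().numel())
    print(data.storage().size() == data.numel())  # True

Keeping the tensors alive and shrinking their storage to zero, as teardown() does, presumably lets the bucketer keep reusing the same Bucket objects across iterations while still returning the buffer memory to the allocator between backward passes; setup() then simply grows the storage back the next time the bucket is used.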