From 932d1266fa4ba9bb6f1358027a5eee3c5cd613d6 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Mon, 1 Mar 2021 13:59:14 -0500 Subject: [PATCH 1/9] copy PR description, start CPU and mixed precision sections --- fairscale/nn/data_parallel/README_fsdp.md | 80 +++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 fairscale/nn/data_parallel/README_fsdp.md diff --git a/fairscale/nn/data_parallel/README_fsdp.md b/fairscale/nn/data_parallel/README_fsdp.md new file mode 100644 index 000000000..6c7171f2b --- /dev/null +++ b/fairscale/nn/data_parallel/README_fsdp.md @@ -0,0 +1,80 @@ +### Overview +Recent work by [Microsoft](https://arxiv.org/abs/1910.02054) and [Google](https://arxiv.org/abs/2004.13336) has shown that data parallel training can be made significantly more efficient by sharding the model parameters and optimizer state across data parallel workers. These ideas are encapsulated in the new **`FullyShardedDataParallel` (FSDP)** wrapper, which is a drop-in replacement for PyTorch's `DistributedDataParallel` (DDP) wrapper. + +Compared to PyTorch `DistributedDataParallel` (DDP): +* FSDP shards parameters (FP16 + FP32) and optimizer state across data parallel GPUs +* FSDP with `reshard_after_forward=False` has the same communication cost as PyTorch DDP and is similar to ZeRO-2 +* FSDP with `reshard_after_forward=True` increases total communication by 50% and is similar to ZeRO-3: + * all-gather parameters at start of forward pass and start of backward pass + * reduce-scatter grads at end of backward pass +* in practice, FSDP is faster than PyTorch DDP because the optimizer step is sharded, and the extra communication can be overlapped with the forward pass +* FSDP enables training 13B parameter models on 8 GPUs and 175B parameter models on 128 GPUs. When using the `cpu_offload=True` option, it's possible to train 1T parameter models on 256 GPUs. + +### General usage notes +- for best memory efficiency wrap each layer in your network with FSDP and set `reshard_after_forward=True` +- for best training speed set `reshard_after_forward=False` (wrapping each layer is not required, but will improve speed further) +- if you're using `torch.cuda.amp.autocast` for mixed precision, that's fully compatible with the FSDP wrapper, just set `mixed_precision=True` +- if combining with [activation checkpointing](https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/misc/checkpoint_activations.py), prefer `FSDP(checkpoint_wrapper(module))` over `checkpoint_wrapper(FSDP(module))`. The latter will result in more communication and will be slower. +- this is full compatible with pointwise Optimizers, e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.. However, the sharding will result in slightly different results when using non-pointwise Optimizers, e.g., Adagrad, Adafactor, LAMB, etc. + +### How it works +In standard distributed data parallel (DDP) training every worker processes a separate batch and the gradients are summed across workers using an [all-reduce operation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce). While DDP has become very popular, it wastes GPU memory because the model weights and optimizer states are replicated across all DDP workers. + +The key insight to unlock full parameter sharding is that we can decompose the [all-reduce](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce) operation in DDP into separate [all-gather](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allgather) and [reduce-scatter](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#reducescatter) operations: + +Screen Shot 2021-01-12 at 12 35 19 PM + +Then, we can rearrange the reduce-scatter + all-gather so that each DDP worker only needs to store a single shard of parameters and optimizer state. The figure below illustrates standard DDP training (left) and fully sharded training (right): + +Screen Shot 2021-02-24 at 4 39 55 PM + +To maximize memory efficiency we can discard the full weights after each layer's forward pass, saving memory for subsequent layers. This can be implemented by applying the FSDP wrapper to every layer in your network (with `reshard_after_forward=True`). In pseudo-code: +``` +FSDP forward pass: + for layer_i in layers: + all-gather full weights for layer_i + forward pass for layer_i + discard full weights for layer_i +FSDP backward pass: + for layer_i in layers: + all-gather full weights for layer_i + backward pass for layer_i + discard full weights for layer_i + reduce-scatter gradients for layer_i +``` +#### Mixed Precision + +When `mixed_precision=True`: + +- Sharded parameters are downcast to `fp16` before `forward`, promoted to `fp32` after forward. +- buffers: batch norm not handled in any special way, buffers are kept in `fp16`. Buffers are not sharded regardless of arguments. + +- By default, gradients will be computed and reduced `fp32_reduce_scatter` controls +- FIXME: If `torch.amp.autocast` is enabled it will over-ride the output dtypes of some operations + + +#### Using CPU RAM + +`move_grads_to_cpu` and `cpu_offload` control which tensors get moved to CPU. + +- `cpu_offload` moves weights to CPU when they are not being used. +- `move_grads_to_cpu` moves gradients to CPU. The use of this option requires that the optimizer has a copy of the model parameters on CPU. + +#### Gradient Clipping +By default, +```python +sharded_module = FullyShardedDataParallel(my_module) +torch.nn.utils.clip_grad_norm_(sharded_module.parameters(), max_norm=1.0) +``` +will use an incorrect norm (the norm over all params in a shard) when clipping gradients. +To overcome this, you can either call +`sharded_module.clip_grad_norm(1.0)` +which does the extra computation required to compute the norm properly, or use `torch.nn.utils.clip_grad_value_`. +``` + + + +#### Misc + +- we don't start the FP32 -> FP16 + # transfer until after the optimization step completes. From 6bcf783e90ce967475eb8337159296c97a84e71a Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Thu, 4 Mar 2021 23:52:09 -0500 Subject: [PATCH 2/9] RST attempt --- docs/source/api/nn/fsdp.rst | 29 ++++ docs/source/api/nn/fsdp_tips.rst | 140 ++++++++++++++++++ docs/source/index.rst | 4 +- .../fully_sharded_data_parallel.py | 17 ++- 4 files changed, 180 insertions(+), 10 deletions(-) create mode 100644 docs/source/api/nn/fsdp_tips.rst diff --git a/docs/source/api/nn/fsdp.rst b/docs/source/api/nn/fsdp.rst index cd31f069d..55fb0cfd1 100644 --- a/docs/source/api/nn/fsdp.rst +++ b/docs/source/api/nn/fsdp.rst @@ -1,6 +1,35 @@ FullyShardedDataParallel ======================== + + +Signatures +========== .. autoclass:: fairscale.nn.FullyShardedDataParallel :members: :undoc-members: + + +Wrapping +========================= + +There are three cases where the `enable_wrap` context can be useful: + +* When you'd like to apply the same parameters to all child modules that you wrap with FSDP. + * Calling the `wrap` function within the `enable_wrap` context will save you from passing the same set of FSDP kwargs explicitly. + * We recommend it since it will also allow more overlapping! + +* When wrapping large models that do NOT fit within the CPU memory. + I.e. you don't first create the full model and then traverse it to wrap it with FSDP at different parts. Instead, you create a wrapped instance of the model incrementally as you build up the model, allowing large modules to be sharded in-place. + +.. code-block:: python + + from fairscale.nn.wrap import auto_wrap, enable_wrap + from fairscale. + fsdp_params = dict(mixed_precision=True, flatten_parameters=True) + with enable_wrap(**fsdp_params): + # Wraps layer in FSDP by default if within context + self.l1 = wrap(torch.nn.Linear(5, 5)) + assert isinstance(self.l1) + # Separately Wraps children modules with more than 1e8 params + self.l2 = auto_wrap(TransformerBlock(), min_num_params=1e8) diff --git a/docs/source/api/nn/fsdp_tips.rst b/docs/source/api/nn/fsdp_tips.rst new file mode 100644 index 000000000..22e97d8b9 --- /dev/null +++ b/docs/source/api/nn/fsdp_tips.rst @@ -0,0 +1,140 @@ +Overview +~~~~~~~~ + +Recent work by `Microsoft `__ and +`Google `__ has shown that data +parallel training can be made significantly more efficient by sharding +the model parameters and optimizer state across data parallel workers. +These ideas are encapsulated in the new **``FullyShardedDataParallel`` +(FSDP)** wrapper, which is a drop-in replacement for PyTorch's +``DistributedDataParallel`` (DDP) wrapper. + +Compared to PyTorch ``DistributedDataParallel`` (DDP): + + * FSDP shards parameters (FP16 + FP32) and optimizer state across data parallel GPUs + * FSDP with ``reshard_after_forward=False`` has the same communication cost as PyTorch DDP and is similar to ZeRO-2 + * FSDP with``reshard_after_forward=True`` increases total communication by 50% and is similar to ZeRO-3: + * all-gather parameters at start of forward pass and start of backward pass + * reduce-scatter grads at end of backwardpass + * in practice, FSDP is faster than PyTorch DDP because the optimizer step is sharded, and the extra communication can be overlapped with the forward pass + * FSDP enables training 13B parameter models on 8 GPUs and 175B parameter models on 128 GPUs. When using the ``cpu_offload=True`` option, it's possible to train 1T parameter models on 256 GPUs. + + +General usage notes +~~~~~~~~~~~~~~~~~~~ + +- for best memory efficiency use ``auto_wrap`` to wrap each layer in your network with ``FSDP`` and set ``reshard_after_forward=True`` +- for best training speed set ``reshard_after_forward=False`` (wrapping each layer is not required, but will improve speed further) +- if you're using ``torch.cuda.amp.autocast`` for mixed precision, that's fully compatible with the FSDP wrapper, just set ``mixed_precision=True`` +- if combining with `activation checkpointing `__, + prefer ``FSDP(checkpoint_wrapper(module))`` over + ``checkpoint_wrapper(FSDP(module))``. The latter will result in more + communication and will be slower. +- results should be identical to DDP with pointwise Optimizers, e.g., + Adam, AdamW, Adadelta, Adamax, SGD, etc.. However, the sharding will + result in slightly different results when using non-pointwise + Optimizers, e.g., Adagrad, Adafactor, LAMB, etc. + +How it works +~~~~~~~~~~~~ +In standard distributed data parallel (DDP) training every worker processes a separate batch and the gradients are summed across workers using an `all-reduce operation `__. +While DDP has become very popular, it wastes GPU memory because the model weights and optimizer states are replicated across all DDP workers. + +The key insight to unlock full parameter sharding is that we can decompose the +`all-reduce `__ +operation in DDP into separate +`all-gather `__ +and +`reduce-scatter `__ +operations: + +-- image:: https://user-images.githubusercontent.com/231798/108780259-26870a00-7536-11eb-890d-51720f39d098.png + :alt: fig1 + + +Then, we can rearrange the reduce-scatter + all-gather so that each DDP worker only needs to store a single shard of parameters and optimizer state. The figure below illustrates standard DDP training (left) and fully sharded training (right): + +-- image:: https://user-images.githubusercontent.com/231798/109069252-f9199800-76be-11eb-96f8-86767edf1eb9.png + :alt: fig2 + +To maximize memory efficiency we can discard the full weights after each +layer's forward pass, saving memory for subsequent layers. This can be +implemented by applying the FSDP wrapper to every layer in your network +(with ``reshard_after_forward=True``). In pseudo-code: + +:: + + FSDP forward pass: + for layer_i in layers: + all-gather full weights for layer_i + forward pass for layer_i + discard full weights for layer_i + FSDP backward pass: + for layer_i in layers: + all-gather full weights for layer_i + backward pass for layer_i + discard full weights for layer_i + reduce-scatter gradients for layer_i + +Mixed Precision +^^^^^^^^^^^^^^^ + +When ``mixed_precision=True``: + +- Sharded parameters are downcast to ``fp16`` before ``forward``, + promoted to ``fp32`` after forward. +- buffers: batch norm not handled in any special way, buffers are kept + in ``fp16``. Buffers are not sharded regardless of arguments. + +- By default, gradients will be computed and reduced + ``fp32_reduce_scatter`` controls +- FIXME: If ``torch.amp.autocast`` is enabled it will over-ride the + output dtypes of some operations + +Using CPU RAM +^^^^^^^^^^^^^ + +``move_grads_to_cpu`` and ``cpu_offload`` control which tensors get +moved to CPU. + +- ``cpu_offload`` moves weights to CPU when they are not being used. +- ``move_grads_to_cpu`` moves gradients to CPU. The use of this option + requires that the optimizer has a copy of the model parameters on + CPU. + +Gradient Clipping +^^^^^^^^^^^^^^^^^ + +By default, + +.. code:: python + + sharded_module = FullyShardedDataParallel(my_module) + torch.nn.utils.clip_grad_norm_(sharded_module.parameters(), max_norm=1.0) + +will use an incorrect norm (the norm over all params in a shard) when +clipping gradients. To overcome this, you can either call +``sharded_module.clip_grad_norm(1.0)`` which does the extra computation +required to compute the norm properly, or use +``torch.nn.utils.clip_grad_value_``. + +Auto-wrap +~~~~~~~~~ + +.. code:: python + + from fairscale.nn.wrap import auto_wrap, enable_wrap + from fairscale. + fsdp_params = dict(mixed_precision=True, flatten_parameters=True) + with enable_wrap(**fsdp_params): + # Wraps layer in FSDP by default if within context + self.l1 = wrap(torch.nn.Linear(5, 5)) + assert isinstance(self.l1) + # Separately Wraps children modules with more than 1e8 params + self.l2 = auto_wrap(TransformerBlock(), min_num_params=1e8) + +Misc +^^^^ + +- we don't start the FP32 -> FP16 transfer until after the optimization step completes. + diff --git a/docs/source/index.rst b/docs/source/index.rst index e2921b536..f6f80771c 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -23,7 +23,7 @@ Components * `Optimizer state sharding <../../en/latest/api/optim/oss.html>`_ * `Sharded grad scaler - automatic mixed precision <../../en/latest/api/optim/grad_scaler.html>`_ * `Sharded distributed data parallel <../../en/latest/api/nn/sharded_ddp.html>`_ - * `Fully Sharded Data Parallel FSDP <../../en/latest/api/nn/fsdp.html>`_ + * `Fully Sharded Data Parallel FSDP <../../en/latest/api/nn/fsdp.html>`_, `FSDP Tips <../../en/latest/api/nn/fsdp_tips.html>`_ * Optimization at scale: * `AdaScale SGD <../../en/latest/api/optim/adascale.html>`_ @@ -39,7 +39,7 @@ Components This library is under active development. Please be mindful and create an `issue `_ - if you have any trouble and/or suggestion. + if you have any trouble and/or suggestions. .. toctree:: :maxdepth: 5 diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 730e3a2a3..740193f00 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -79,14 +79,15 @@ class FullyShardedDataParallel(nn.Module): models and to improve training speed by overlapping the all-gather step across the forward pass. For example:: - sharded_model = FullyShardedDataParallel( - nn.Sequential( # doesn't have to be nn.Sequential - nn.Linear(5, 100), - FullyShardedDataParallel(nn.Linear(100, 100)), - FullyShardedDataParallel(nn.Linear(100, 100)), - nn.Linear(100, 5), - ) - ) + from fairscale.nn.auto_wrap import enable_wrap, auto_wrap + from fairscale. + fsdp_params = dict(mixed_precision=True, flatten_parameters=True) + with enable_wrap(**fsdp_params): + # Wraps layer in FSDP by default if within context + self.l1 = wrap(torch.nn.Linear(5, 5)) + assert isinstance(self.l1) + # Separately Wraps children modules with more than 1e8 params + self.l2 = auto_wrap(TransformerBlock(), min_num_params=1e8) .. warning:: From 2a2a366c8aa91241dff9546df1831233602e45a9 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Fri, 5 Mar 2021 00:08:07 -0500 Subject: [PATCH 3/9] cleanup --- docs/source/api/index.rst | 1 + docs/source/api/nn/fsdp_tips.rst | 26 +++++++++++++++----------- docs/source/index.rst | 3 ++- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index e04179bdf..45a59ddc9 100644 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -10,4 +10,5 @@ API Reference nn/pipe nn/sharded_ddp nn/fsdp + nn/fsdp_tips nn/misc/checkpoint_activations diff --git a/docs/source/api/nn/fsdp_tips.rst b/docs/source/api/nn/fsdp_tips.rst index 22e97d8b9..af898ec3b 100644 --- a/docs/source/api/nn/fsdp_tips.rst +++ b/docs/source/api/nn/fsdp_tips.rst @@ -1,5 +1,9 @@ +FSDP Tips and Tricks +======================================== + + Overview -~~~~~~~~ +------- Recent work by `Microsoft `__ and `Google `__ has shown that data @@ -11,17 +15,17 @@ These ideas are encapsulated in the new **``FullyShardedDataParallel`` Compared to PyTorch ``DistributedDataParallel`` (DDP): - * FSDP shards parameters (FP16 + FP32) and optimizer state across data parallel GPUs - * FSDP with ``reshard_after_forward=False`` has the same communication cost as PyTorch DDP and is similar to ZeRO-2 - * FSDP with``reshard_after_forward=True`` increases total communication by 50% and is similar to ZeRO-3: - * all-gather parameters at start of forward pass and start of backward pass - * reduce-scatter grads at end of backwardpass - * in practice, FSDP is faster than PyTorch DDP because the optimizer step is sharded, and the extra communication can be overlapped with the forward pass - * FSDP enables training 13B parameter models on 8 GPUs and 175B parameter models on 128 GPUs. When using the ``cpu_offload=True`` option, it's possible to train 1T parameter models on 256 GPUs. +* FSDP shards parameters (FP16 + FP32) and optimizer state across data parallel GPUs +* FSDP with ``reshard_after_forward=False`` has the same communication cost as PyTorch DDP and is similar to ZeRO-2 +* FSDP with ``reshard_after_forward=True`` increases total communication by 50% and is similar to ZeRO-3: + * all-gather parameters at start of forward pass and start of backward pass + * reduce-scatter grads at end of backwardpass +* In practice, FSDP is faster than PyTorch DDP because the optimizer step is sharded, and the extra communication can be overlapped with the forward pass. +* FSDP enables training 13B parameter models on 8 GPUs and 175B parameter models on 128 GPUs. When using the ``cpu_offload=True`` option, it's possible to train 1T parameter models on 256 GPUs. General usage notes -~~~~~~~~~~~~~~~~~~~ +------------------ - for best memory efficiency use ``auto_wrap`` to wrap each layer in your network with ``FSDP`` and set ``reshard_after_forward=True`` - for best training speed set ``reshard_after_forward=False`` (wrapping each layer is not required, but will improve speed further) @@ -36,7 +40,7 @@ General usage notes Optimizers, e.g., Adagrad, Adafactor, LAMB, etc. How it works -~~~~~~~~~~~~ +------------ In standard distributed data parallel (DDP) training every worker processes a separate batch and the gradients are summed across workers using an `all-reduce operation `__. While DDP has become very popular, it wastes GPU memory because the model weights and optimizer states are replicated across all DDP workers. @@ -77,7 +81,7 @@ implemented by applying the FSDP wrapper to every layer in your network reduce-scatter gradients for layer_i Mixed Precision -^^^^^^^^^^^^^^^ +-------------- When ``mixed_precision=True``: diff --git a/docs/source/index.rst b/docs/source/index.rst index f6f80771c..52769b251 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -23,7 +23,8 @@ Components * `Optimizer state sharding <../../en/latest/api/optim/oss.html>`_ * `Sharded grad scaler - automatic mixed precision <../../en/latest/api/optim/grad_scaler.html>`_ * `Sharded distributed data parallel <../../en/latest/api/nn/sharded_ddp.html>`_ - * `Fully Sharded Data Parallel FSDP <../../en/latest/api/nn/fsdp.html>`_, `FSDP Tips <../../en/latest/api/nn/fsdp_tips.html>`_ + * `Fully Sharded Data Parallel FSDP <../../en/latest/api/nn/fsdp.html>`_ + * `FSDP Tips <../../en/latest/api/nn/fsdp_tips.html>`_ * Optimization at scale: * `AdaScale SGD <../../en/latest/api/optim/adascale.html>`_ From ada0d97c1768e0a8431af25aeb161ca115a24f03 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Fri, 5 Mar 2021 10:45:55 -0500 Subject: [PATCH 4/9] Cross link, fix typos --- docs/source/api/nn/fsdp.rst | 28 +----- docs/source/api/nn/fsdp_tips.rst | 150 +++++++++++++++++++++---------- 2 files changed, 102 insertions(+), 76 deletions(-) diff --git a/docs/source/api/nn/fsdp.rst b/docs/source/api/nn/fsdp.rst index 55fb0cfd1..97046e856 100644 --- a/docs/source/api/nn/fsdp.rst +++ b/docs/source/api/nn/fsdp.rst @@ -1,35 +1,9 @@ FullyShardedDataParallel ======================== +See :doc:`FSDP Notes ` for a discussion of the principles behind ``FSDP`` and advanced usage. -Signatures -========== .. autoclass:: fairscale.nn.FullyShardedDataParallel :members: :undoc-members: - - -Wrapping -========================= - -There are three cases where the `enable_wrap` context can be useful: - -* When you'd like to apply the same parameters to all child modules that you wrap with FSDP. - * Calling the `wrap` function within the `enable_wrap` context will save you from passing the same set of FSDP kwargs explicitly. - * We recommend it since it will also allow more overlapping! - -* When wrapping large models that do NOT fit within the CPU memory. - I.e. you don't first create the full model and then traverse it to wrap it with FSDP at different parts. Instead, you create a wrapped instance of the model incrementally as you build up the model, allowing large modules to be sharded in-place. - -.. code-block:: python - - from fairscale.nn.wrap import auto_wrap, enable_wrap - from fairscale. - fsdp_params = dict(mixed_precision=True, flatten_parameters=True) - with enable_wrap(**fsdp_params): - # Wraps layer in FSDP by default if within context - self.l1 = wrap(torch.nn.Linear(5, 5)) - assert isinstance(self.l1) - # Separately Wraps children modules with more than 1e8 params - self.l2 = auto_wrap(TransformerBlock(), min_num_params=1e8) diff --git a/docs/source/api/nn/fsdp_tips.rst b/docs/source/api/nn/fsdp_tips.rst index af898ec3b..01cd31fb7 100644 --- a/docs/source/api/nn/fsdp_tips.rst +++ b/docs/source/api/nn/fsdp_tips.rst @@ -1,47 +1,48 @@ -FSDP Tips and Tricks +FSDP Notes ======================================== - +This document describes how `FSDP` works, including subtle behaviors that can change performance significantly. +See :doc:`FullyShardedDataParallel ` for python docstrings. Overview -------- +--------- Recent work by `Microsoft `__ and `Google `__ has shown that data parallel training can be made significantly more efficient by sharding the model parameters and optimizer state across data parallel workers. -These ideas are encapsulated in the new **``FullyShardedDataParallel`` -(FSDP)** wrapper, which is a drop-in replacement for PyTorch's +These ideas are encapsulated in the new ``FullyShardedDataParallel``_ +(FSDP) wrapper, which is a drop-in replacement for the PyTorch ``DistributedDataParallel`` (DDP) wrapper. -Compared to PyTorch ``DistributedDataParallel`` (DDP): +Compared to PyTorch ``DistributedDataParallel``: * FSDP shards parameters (FP16 + FP32) and optimizer state across data parallel GPUs * FSDP with ``reshard_after_forward=False`` has the same communication cost as PyTorch DDP and is similar to ZeRO-2 * FSDP with ``reshard_after_forward=True`` increases total communication by 50% and is similar to ZeRO-3: * all-gather parameters at start of forward pass and start of backward pass - * reduce-scatter grads at end of backwardpass -* In practice, FSDP is faster than PyTorch DDP because the optimizer step is sharded, and the extra communication can be overlapped with the forward pass. + * reduce-scatter grads at end of the backward pass +* In practice, FSDP is faster than DDP because the optimizer step is sharded, and the extra communication can be overlapped with the forward pass. * FSDP enables training 13B parameter models on 8 GPUs and 175B parameter models on 128 GPUs. When using the ``cpu_offload=True`` option, it's possible to train 1T parameter models on 256 GPUs. General usage notes ------------------- - -- for best memory efficiency use ``auto_wrap`` to wrap each layer in your network with ``FSDP`` and set ``reshard_after_forward=True`` -- for best training speed set ``reshard_after_forward=False`` (wrapping each layer is not required, but will improve speed further) -- if you're using ``torch.cuda.amp.autocast`` for mixed precision, that's fully compatible with the FSDP wrapper, just set ``mixed_precision=True`` -- if combining with `activation checkpointing `__, - prefer ``FSDP(checkpoint_wrapper(module))`` over - ``checkpoint_wrapper(FSDP(module))``. The latter will result in more - communication and will be slower. -- results should be identical to DDP with pointwise Optimizers, e.g., +-------------------- + +- For best memory efficiency use ``auto_wrap`` to wrap each layer in your network with ``FSDP`` and set ``reshard_after_forward=True`` +- For best training speed set ``reshard_after_forward=False`` (wrapping each layer is not required, but will improve speed further) +- If you're using ``torch.cuda.amp.autocast`` for mixed precision, that's fully compatible with the FSDP wrapper, just set ``mixed_precision=True`` +- Uf combining with `activation checkpointing `__, + prefer ``FSDP(checkpoint_wrapper(module))`` over ``checkpoint_wrapper(FSDP(module))``. The latter will result in more communication and will be slower. +- Results should be identical to DDP with pointwise Optimizers, e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.. However, the sharding will result in slightly different results when using non-pointwise Optimizers, e.g., Adagrad, Adafactor, LAMB, etc. + How it works ------------ -In standard distributed data parallel (DDP) training every worker processes a separate batch and the gradients are summed across workers using an `all-reduce operation `__. +In standard distributed data parallel (DDP) training every worker processes a separate batch and the gradients are +summed across workers using an `all-reduce operation `__. While DDP has become very popular, it wastes GPU memory because the model weights and optimizer states are replicated across all DDP workers. The key insight to unlock full parameter sharding is that we can decompose the @@ -52,14 +53,14 @@ and `reduce-scatter `__ operations: --- image:: https://user-images.githubusercontent.com/231798/108780259-26870a00-7536-11eb-890d-51720f39d098.png - :alt: fig1 +.. |Figure 1| image:: https://user-images.githubusercontent.com/231798/108780259-26870a00-7536-11eb-890d-51720f39d098.png Then, we can rearrange the reduce-scatter + all-gather so that each DDP worker only needs to store a single shard of parameters and optimizer state. The figure below illustrates standard DDP training (left) and fully sharded training (right): --- image:: https://user-images.githubusercontent.com/231798/109069252-f9199800-76be-11eb-96f8-86767edf1eb9.png - :alt: fig2 +.. |Figure 2| image:: https://user-images.githubusercontent.com/231798/109069252-f9199800-76be-11eb-96f8-86767edf1eb9.png + +|Figure 2| To maximize memory efficiency we can discard the full weights after each layer's forward pass, saving memory for subsequent layers. This can be @@ -80,23 +81,69 @@ implemented by applying the FSDP wrapper to every layer in your network discard full weights for layer_i reduce-scatter gradients for layer_i +Saving and Loading +------------------ + +There are two ways to load and save FSDP instances, + +- ``state_dict()`` returns a dictionary containing all parameters, which can be loaded with ``load_local_state_dict()`` +- ``local_state_dict()`` returns a dictionary containing a shard's parameters, which can be loaded with ``load_local_state_dict()`` + + Mixed Precision --------------- +--------------- When ``mixed_precision=True``: -- Sharded parameters are downcast to ``fp16`` before ``forward``, - promoted to ``fp32`` after forward. -- buffers: batch norm not handled in any special way, buffers are kept - in ``fp16``. Buffers are not sharded regardless of arguments. - +- Sharded parameters are downcast to ``fp16`` before ``forward``, promoted to ``fp32`` after forward. +- buffers are kept in ``fp16``, unless ``buffer_dtype=torch.float32`` is passed. Buffers are not sharded regardless of arguments. - By default, gradients will be computed and reduced - ``fp32_reduce_scatter`` controls -- FIXME: If ``torch.amp.autocast`` is enabled it will over-ride the - output dtypes of some operations +- ``fp32_reduce_scatter=True`` controls the quantization of the gradient communication +- If ``torch.amp.autocast`` is enabled it will override the output dtypes of some operations, like ``BatchNorm2D`` + + +Auto-wrap +~~~~~~~~~ +Auto wrapping sub-modules with ``FSDP`` is a convenient way to improve training speed by overlapping the all-gather step across the forward passes of different submodules. + + + +.. code-block:: python + + import torch + from fairscale.nn.wrap import auto_wrap, enable_wrap, wrap + from fairscale.nn.data_parallel import FullyShardedDataParallel + from fairscale.utils.testing import DummyProcessGroup + tfmr = torch.nn.Transformer(num_encoder_layers=2, num_decoder_layers=2) + + group = DummyProcessGroup(rank=0, size=1) + fsdp_params = dict(mixed_precision=True, flatten_parameters=True) + with enable_wrap(process_group=group, **fsdp_params): + + # Wraps layer in FSDP by default if within context + l1 = wrap(torch.nn.Linear(5, 5)) + assert isinstance(l1, FullyShardedDataParallel) + assert l1.mixed_precision and l1.flatten_parameters + # Separately Wraps children modules with more than 1e8 params + tfmr_auto_wrapped = auto_wrap(tfmr, min_num_params=1e6) + assert isinstance(l2, nn.Transformer) + for l in l2.encoder.layers: + assert isinstance(l, FullyShardedDataParallel) + assert l.mixed_precision and l.flatten_parameters + assert isinstance(l.linear1, FullyShardedDataParallel) + assert isinstance(l.linear2, FullyShardedDataParallel) + assert not isinstance(l.self_attn, FullyShardedDataParallel) # self attention is not auto-wrapped + + +.. warning:: It is not recommended to use :func:`auto_wrap` with + :class:`FullyShardedDataParallel` on modules that have shared + parameters, as the parameter sharing may be broken (i.e. end up not + shared) if the shared parameters are not (auto-)wrapped under the same + FSDP wrapper instance. + Using CPU RAM -^^^^^^^^^^^^^ +------------- ``move_grads_to_cpu`` and ``cpu_offload`` control which tensors get moved to CPU. @@ -107,11 +154,11 @@ moved to CPU. CPU. Gradient Clipping -^^^^^^^^^^^^^^^^^ +----------------- By default, -.. code:: python +.. code-block:: python sharded_module = FullyShardedDataParallel(my_module) torch.nn.utils.clip_grad_norm_(sharded_module.parameters(), max_norm=1.0) @@ -122,23 +169,28 @@ clipping gradients. To overcome this, you can either call required to compute the norm properly, or use ``torch.nn.utils.clip_grad_value_``. -Auto-wrap -~~~~~~~~~ -.. code:: python +State Management with extra parameter attributes +------------------------------------------------ - from fairscale.nn.wrap import auto_wrap, enable_wrap - from fairscale. - fsdp_params = dict(mixed_precision=True, flatten_parameters=True) - with enable_wrap(**fsdp_params): - # Wraps layer in FSDP by default if within context - self.l1 = wrap(torch.nn.Linear(5, 5)) - assert isinstance(self.l1) - # Separately Wraps children modules with more than 1e8 params - self.l2 = auto_wrap(TransformerBlock(), min_num_params=1e8) +We manage several attributes on each Parameter instance. The first two +are set by :func:`_shard_parameters_`: -Misc -^^^^ + ``_is_sharded``: ``True`` if the Parameter is sharded or ``False`` + if the Parameter is intentionally not sharded (in which case we + will all-reduce grads for this param). + ``_orig_size``: the size of the original Parameter (before sharding) +The remaining attributes are set in :func:``_init_param_attributes()``: + ``_fp32_shard``: a single shard of the parameters in full precision + (typically FP32, but this is dependent on the dtype of the model + as it's passed in by the user). This can be on CPU or GPU depending on the value of *``cpu_offload``*. + ``_fp16_shard``: if *``mixed_precision``* is ``True``, this will be + a single shard of the parameters in FP16, used for all-gather. + ``_full_param_padded``: the full weight (padded to be evenly divisible by ``world_size``), used for computation in the + forward and backward pass. This will be resized in place and only materialized (via all-gather) as needed. + +Misc +---- - we don't start the FP32 -> FP16 transfer until after the optimization step completes. From 0733d5943d6d6df9700382f933dc134780cd3242 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Fri, 5 Mar 2021 10:48:52 -0500 Subject: [PATCH 5/9] Remove .md --- fairscale/nn/data_parallel/README_fsdp.md | 80 ----------------------- 1 file changed, 80 deletions(-) delete mode 100644 fairscale/nn/data_parallel/README_fsdp.md diff --git a/fairscale/nn/data_parallel/README_fsdp.md b/fairscale/nn/data_parallel/README_fsdp.md deleted file mode 100644 index 6c7171f2b..000000000 --- a/fairscale/nn/data_parallel/README_fsdp.md +++ /dev/null @@ -1,80 +0,0 @@ -### Overview -Recent work by [Microsoft](https://arxiv.org/abs/1910.02054) and [Google](https://arxiv.org/abs/2004.13336) has shown that data parallel training can be made significantly more efficient by sharding the model parameters and optimizer state across data parallel workers. These ideas are encapsulated in the new **`FullyShardedDataParallel` (FSDP)** wrapper, which is a drop-in replacement for PyTorch's `DistributedDataParallel` (DDP) wrapper. - -Compared to PyTorch `DistributedDataParallel` (DDP): -* FSDP shards parameters (FP16 + FP32) and optimizer state across data parallel GPUs -* FSDP with `reshard_after_forward=False` has the same communication cost as PyTorch DDP and is similar to ZeRO-2 -* FSDP with `reshard_after_forward=True` increases total communication by 50% and is similar to ZeRO-3: - * all-gather parameters at start of forward pass and start of backward pass - * reduce-scatter grads at end of backward pass -* in practice, FSDP is faster than PyTorch DDP because the optimizer step is sharded, and the extra communication can be overlapped with the forward pass -* FSDP enables training 13B parameter models on 8 GPUs and 175B parameter models on 128 GPUs. When using the `cpu_offload=True` option, it's possible to train 1T parameter models on 256 GPUs. - -### General usage notes -- for best memory efficiency wrap each layer in your network with FSDP and set `reshard_after_forward=True` -- for best training speed set `reshard_after_forward=False` (wrapping each layer is not required, but will improve speed further) -- if you're using `torch.cuda.amp.autocast` for mixed precision, that's fully compatible with the FSDP wrapper, just set `mixed_precision=True` -- if combining with [activation checkpointing](https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/misc/checkpoint_activations.py), prefer `FSDP(checkpoint_wrapper(module))` over `checkpoint_wrapper(FSDP(module))`. The latter will result in more communication and will be slower. -- this is full compatible with pointwise Optimizers, e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.. However, the sharding will result in slightly different results when using non-pointwise Optimizers, e.g., Adagrad, Adafactor, LAMB, etc. - -### How it works -In standard distributed data parallel (DDP) training every worker processes a separate batch and the gradients are summed across workers using an [all-reduce operation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce). While DDP has become very popular, it wastes GPU memory because the model weights and optimizer states are replicated across all DDP workers. - -The key insight to unlock full parameter sharding is that we can decompose the [all-reduce](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce) operation in DDP into separate [all-gather](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allgather) and [reduce-scatter](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#reducescatter) operations: - -Screen Shot 2021-01-12 at 12 35 19 PM - -Then, we can rearrange the reduce-scatter + all-gather so that each DDP worker only needs to store a single shard of parameters and optimizer state. The figure below illustrates standard DDP training (left) and fully sharded training (right): - -Screen Shot 2021-02-24 at 4 39 55 PM - -To maximize memory efficiency we can discard the full weights after each layer's forward pass, saving memory for subsequent layers. This can be implemented by applying the FSDP wrapper to every layer in your network (with `reshard_after_forward=True`). In pseudo-code: -``` -FSDP forward pass: - for layer_i in layers: - all-gather full weights for layer_i - forward pass for layer_i - discard full weights for layer_i -FSDP backward pass: - for layer_i in layers: - all-gather full weights for layer_i - backward pass for layer_i - discard full weights for layer_i - reduce-scatter gradients for layer_i -``` -#### Mixed Precision - -When `mixed_precision=True`: - -- Sharded parameters are downcast to `fp16` before `forward`, promoted to `fp32` after forward. -- buffers: batch norm not handled in any special way, buffers are kept in `fp16`. Buffers are not sharded regardless of arguments. - -- By default, gradients will be computed and reduced `fp32_reduce_scatter` controls -- FIXME: If `torch.amp.autocast` is enabled it will over-ride the output dtypes of some operations - - -#### Using CPU RAM - -`move_grads_to_cpu` and `cpu_offload` control which tensors get moved to CPU. - -- `cpu_offload` moves weights to CPU when they are not being used. -- `move_grads_to_cpu` moves gradients to CPU. The use of this option requires that the optimizer has a copy of the model parameters on CPU. - -#### Gradient Clipping -By default, -```python -sharded_module = FullyShardedDataParallel(my_module) -torch.nn.utils.clip_grad_norm_(sharded_module.parameters(), max_norm=1.0) -``` -will use an incorrect norm (the norm over all params in a shard) when clipping gradients. -To overcome this, you can either call -`sharded_module.clip_grad_norm(1.0)` -which does the extra computation required to compute the norm properly, or use `torch.nn.utils.clip_grad_value_`. -``` - - - -#### Misc - -- we don't start the FP32 -> FP16 - # transfer until after the optimization step completes. From e86920c77fa7507ded5971a6445773ac0b6cf93e Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Sun, 7 Mar 2021 16:47:46 -0500 Subject: [PATCH 6/9] Fix warning, new image --- docs/source/api/nn/fsdp_tips.rst | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/docs/source/api/nn/fsdp_tips.rst b/docs/source/api/nn/fsdp_tips.rst index 01cd31fb7..eca39a848 100644 --- a/docs/source/api/nn/fsdp_tips.rst +++ b/docs/source/api/nn/fsdp_tips.rst @@ -31,7 +31,7 @@ General usage notes - For best memory efficiency use ``auto_wrap`` to wrap each layer in your network with ``FSDP`` and set ``reshard_after_forward=True`` - For best training speed set ``reshard_after_forward=False`` (wrapping each layer is not required, but will improve speed further) - If you're using ``torch.cuda.amp.autocast`` for mixed precision, that's fully compatible with the FSDP wrapper, just set ``mixed_precision=True`` -- Uf combining with `activation checkpointing `__, +- If combining with `activation checkpointing `__, prefer ``FSDP(checkpoint_wrapper(module))`` over ``checkpoint_wrapper(FSDP(module))``. The latter will result in more communication and will be slower. - Results should be identical to DDP with pointwise Optimizers, e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.. However, the sharding will @@ -53,8 +53,9 @@ and `reduce-scatter `__ operations: -.. |Figure 1| image:: https://user-images.githubusercontent.com/231798/108780259-26870a00-7536-11eb-890d-51720f39d098.png +.. |Figure 1| image:: https://user-images.githubusercontent.com/23240128/110170085-a67b6280-7dc7-11eb-9128-88d813fc7037.png +|Figure 1| Then, we can rearrange the reduce-scatter + all-gather so that each DDP worker only needs to store a single shard of parameters and optimizer state. The figure below illustrates standard DDP training (left) and fully sharded training (right): @@ -176,21 +177,24 @@ State Management with extra parameter attributes We manage several attributes on each Parameter instance. The first two are set by :func:`_shard_parameters_`: - ``_is_sharded``: ``True`` if the Parameter is sharded or ``False`` - if the Parameter is intentionally not sharded (in which case we - will all-reduce grads for this param). - ``_orig_size``: the size of the original Parameter (before sharding) +- ``_is_sharded``: ``True`` if the Parameter is sharded or ``False`` + if the Parameter is intentionally not sharded (in which case we + will all-reduce grads for this param). +- ``_orig_size``: the size of the original Parameter (before sharding) -The remaining attributes are set in :func:``_init_param_attributes()``: - ``_fp32_shard``: a single shard of the parameters in full precision - (typically FP32, but this is dependent on the dtype of the model - as it's passed in by the user). This can be on CPU or GPU depending on the value of *``cpu_offload``*. - ``_fp16_shard``: if *``mixed_precision``* is ``True``, this will be - a single shard of the parameters in FP16, used for all-gather. - ``_full_param_padded``: the full weight (padded to be evenly divisible by ``world_size``), used for computation in the - forward and backward pass. This will be resized in place and only materialized (via all-gather) as needed. + +The remaining attributes are set in ``_init_param_attributes()``: + +- ``_fp32_shard``: a single shard of the parameters in full precision + (typically FP32, but this is dependent on the dtype of the model + as it's passed in by the user). This can be on CPU or GPU depending on the value of *``cpu_offload``*. +- ``_fp16_shard``: if ``mixed_precision`` is ``True``, this will be + a single shard of the parameters in FP16, used for all-gather. +- ``_full_param_padded``: the full weight (padded to be evenly divisible by ``world_size``), used for computation in the + forward and backward pass. This will be resized in place and only materialized (via all-gather) as needed. Misc ---- - we don't start the FP32 -> FP16 transfer until after the optimization step completes. +- any direct weight accesses outside of the fwd/bwd, should be in the ``_summon_full_params`` context From 3d60cb78496ec2a469766a24a3a1b8f0fe2e04b8 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Mon, 8 Mar 2021 13:17:31 -0500 Subject: [PATCH 7/9] link fairseq --- docs/source/api/nn/fsdp_tips.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/api/nn/fsdp_tips.rst b/docs/source/api/nn/fsdp_tips.rst index eca39a848..41a290fbf 100644 --- a/docs/source/api/nn/fsdp_tips.rst +++ b/docs/source/api/nn/fsdp_tips.rst @@ -1,6 +1,6 @@ FSDP Notes ======================================== -This document describes how `FSDP` works, including subtle behaviors that can change performance significantly. +This document describes how ``FSDP`` works, including subtle behaviors that can change performance significantly. See :doc:`FullyShardedDataParallel ` for python docstrings. Overview @@ -37,7 +37,7 @@ General usage notes Adam, AdamW, Adadelta, Adamax, SGD, etc.. However, the sharding will result in slightly different results when using non-pointwise Optimizers, e.g., Adagrad, Adafactor, LAMB, etc. - +- In `fairseq `_, FSDP is activated by the command line option ``--ddp-backend=fully_sharded``. How it works ------------ From d07308830a0e086578e5eec530304bdfdb9ac16b Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Mon, 8 Mar 2021 15:12:06 -0500 Subject: [PATCH 8/9] better anchor text --- docs/source/api/nn/fsdp.rst | 4 ++-- docs/source/api/nn/fsdp_tips.rst | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/api/nn/fsdp.rst b/docs/source/api/nn/fsdp.rst index 97046e856..6db33b229 100644 --- a/docs/source/api/nn/fsdp.rst +++ b/docs/source/api/nn/fsdp.rst @@ -1,5 +1,5 @@ -FullyShardedDataParallel -======================== +Fully Sharded Data Parallel +======================================================= See :doc:`FSDP Notes ` for a discussion of the principles behind ``FSDP`` and advanced usage. diff --git a/docs/source/api/nn/fsdp_tips.rst b/docs/source/api/nn/fsdp_tips.rst index 41a290fbf..480508371 100644 --- a/docs/source/api/nn/fsdp_tips.rst +++ b/docs/source/api/nn/fsdp_tips.rst @@ -1,7 +1,7 @@ -FSDP Notes -======================================== +Fully Sharded Data Parallel Notes +======================================================= This document describes how ``FSDP`` works, including subtle behaviors that can change performance significantly. -See :doc:`FullyShardedDataParallel ` for python docstrings. +See :doc:`this page ` for python docstrings. Overview --------- From d5060e9df36de8d3e17b7d40191a5ed66c3ab6aa Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Mon, 8 Mar 2021 15:18:12 -0500 Subject: [PATCH 9/9] Thanks for the comments, Myle --- docs/source/api/nn/fsdp_tips.rst | 24 +++++++++---------- .../fully_sharded_data_parallel.py | 4 ++-- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/docs/source/api/nn/fsdp_tips.rst b/docs/source/api/nn/fsdp_tips.rst index 480508371..5d68358f3 100644 --- a/docs/source/api/nn/fsdp_tips.rst +++ b/docs/source/api/nn/fsdp_tips.rst @@ -88,7 +88,7 @@ Saving and Loading There are two ways to load and save FSDP instances, - ``state_dict()`` returns a dictionary containing all parameters, which can be loaded with ``load_local_state_dict()`` -- ``local_state_dict()`` returns a dictionary containing a shard's parameters, which can be loaded with ``load_local_state_dict()`` +- ``local_state_dict()`` returns a dictionary containing a shard's parameters, which can be loaded with ``load_state_dict()`` Mixed Precision @@ -98,14 +98,14 @@ When ``mixed_precision=True``: - Sharded parameters are downcast to ``fp16`` before ``forward``, promoted to ``fp32`` after forward. - buffers are kept in ``fp16``, unless ``buffer_dtype=torch.float32`` is passed. Buffers are not sharded regardless of arguments. -- By default, gradients will be computed and reduced -- ``fp32_reduce_scatter=True`` controls the quantization of the gradient communication +- By default, gradients will be computed and reduced in FP16. If FP32 reductions are important, set ``fp32_reduce_scatter=True`` - If ``torch.amp.autocast`` is enabled it will override the output dtypes of some operations, like ``BatchNorm2D`` Auto-wrap ~~~~~~~~~ Auto wrapping sub-modules with ``FSDP`` is a convenient way to improve training speed by overlapping the all-gather step across the forward passes of different submodules. +It also improves memory efficiency by freeing gathered parameters after each layer finishes executing. @@ -113,27 +113,27 @@ Auto wrapping sub-modules with ``FSDP`` is a convenient way to improve training import torch from fairscale.nn.wrap import auto_wrap, enable_wrap, wrap - from fairscale.nn.data_parallel import FullyShardedDataParallel + from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP from fairscale.utils.testing import DummyProcessGroup tfmr = torch.nn.Transformer(num_encoder_layers=2, num_decoder_layers=2) group = DummyProcessGroup(rank=0, size=1) fsdp_params = dict(mixed_precision=True, flatten_parameters=True) - with enable_wrap(process_group=group, **fsdp_params): + with enable_wrap(wrapper_cls=FSDP, process_group=group, **fsdp_params): # Wraps layer in FSDP by default if within context l1 = wrap(torch.nn.Linear(5, 5)) - assert isinstance(l1, FullyShardedDataParallel) + assert isinstance(l1, FSDP) assert l1.mixed_precision and l1.flatten_parameters # Separately Wraps children modules with more than 1e8 params tfmr_auto_wrapped = auto_wrap(tfmr, min_num_params=1e6) assert isinstance(l2, nn.Transformer) for l in l2.encoder.layers: - assert isinstance(l, FullyShardedDataParallel) + assert isinstance(l, FSDP) assert l.mixed_precision and l.flatten_parameters - assert isinstance(l.linear1, FullyShardedDataParallel) - assert isinstance(l.linear2, FullyShardedDataParallel) - assert not isinstance(l.self_attn, FullyShardedDataParallel) # self attention is not auto-wrapped + assert isinstance(l.linear1, FSDP) + assert isinstance(l.linear2, FSDP) + assert not isinstance(l.self_attn, FSDP) # self attention is not auto-wrapped .. warning:: It is not recommended to use :func:`auto_wrap` with @@ -150,9 +150,7 @@ Using CPU RAM moved to CPU. - ``cpu_offload`` moves weights to CPU when they are not being used. -- ``move_grads_to_cpu`` moves gradients to CPU. The use of this option - requires that the optimizer has a copy of the model parameters on - CPU. +- ``move_grads_to_cpu`` moves gradients to CPU so that the optimizer step also happens on CPU. This option requires ``cpu_offload=True``. Gradient Clipping ----------------- diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 740193f00..a44b032c4 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -80,12 +80,12 @@ class FullyShardedDataParallel(nn.Module): across the forward pass. For example:: from fairscale.nn.auto_wrap import enable_wrap, auto_wrap - from fairscale. + from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP fsdp_params = dict(mixed_precision=True, flatten_parameters=True) with enable_wrap(**fsdp_params): # Wraps layer in FSDP by default if within context self.l1 = wrap(torch.nn.Linear(5, 5)) - assert isinstance(self.l1) + assert isinstance(self.l1, FSDP) # Separately Wraps children modules with more than 1e8 params self.l2 = auto_wrap(TransformerBlock(), min_num_params=1e8)