[docs] add fsdp_tips.rst #455

Merged · 10 commits · Mar 8, 2021
1 change: 1 addition & 0 deletions docs/source/api/index.rst
@@ -10,4 +10,5 @@ API Reference
nn/pipe
nn/sharded_ddp
nn/fsdp
nn/fsdp_tips
nn/misc/checkpoint_activations
7 changes: 5 additions & 2 deletions docs/source/api/nn/fsdp.rst
@@ -1,5 +1,8 @@
FullyShardedDataParallel
========================
Fully Sharded Data Parallel
=======================================================

See :doc:`FSDP Notes <fsdp_tips>` for a discussion of the principles behind ``FSDP`` and advanced usage.


.. autoclass:: fairscale.nn.FullyShardedDataParallel
:members:
198 changes: 198 additions & 0 deletions docs/source/api/nn/fsdp_tips.rst
@@ -0,0 +1,198 @@
Fully Sharded Data Parallel Notes
=======================================================
This document describes how ``FSDP`` works, including subtle behaviors that can change performance significantly.
See :doc:`this page <fsdp>` for the Python docstrings.

Overview
---------

Recent work by `Microsoft <https://arxiv.org/abs/1910.02054>`__ and
`Google <https://arxiv.org/abs/2004.13336>`__ has shown that data
parallel training can be made significantly more efficient by sharding
the model parameters and optimizer state across data parallel workers.
These ideas are encapsulated in the new ``FullyShardedDataParallel``
(FSDP) wrapper, which is a drop-in replacement for the PyTorch
``DistributedDataParallel`` (DDP) wrapper.

Compared to PyTorch ``DistributedDataParallel``:

* FSDP shards parameters (FP16 + FP32) and optimizer state across data parallel GPUs
* FSDP with ``reshard_after_forward=False`` has the same communication cost as PyTorch DDP and is similar to ZeRO-2
* FSDP with ``reshard_after_forward=True`` increases total communication by 50% and is similar to ZeRO-3:

  * all-gather parameters at the start of the forward pass and at the start of the backward pass
  * reduce-scatter gradients at the end of the backward pass

* In practice, FSDP is faster than DDP because the optimizer step is sharded, and the extra communication can be overlapped with the forward pass.
* FSDP enables training 13B parameter models on 8 GPUs and 175B parameter models on 128 GPUs. When using the ``cpu_offload=True`` option, it's possible to train 1T parameter models on 256 GPUs.


General usage notes
--------------------

- For best memory efficiency, use ``auto_wrap`` to wrap each layer in your network with ``FSDP`` and set ``reshard_after_forward=True``.
- For best training speed, set ``reshard_after_forward=False`` (wrapping each layer is not required, but will improve speed further).
- ``torch.cuda.amp.autocast`` for mixed precision is fully compatible with the FSDP wrapper; just set ``mixed_precision=True``.
- If combining with `activation checkpointing <https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/misc/checkpoint_activations.py>`__,
  prefer ``FSDP(checkpoint_wrapper(module))`` over ``checkpoint_wrapper(FSDP(module))``. The latter will result in more communication and will be slower.
  A short sketch combining these options follows this list.
- Results should be identical to DDP with pointwise optimizers (e.g. Adam, AdamW, Adadelta, Adamax, SGD). However, the sharding will
  result in slightly different results when using non-pointwise optimizers (e.g. Adagrad, Adafactor, LAMB).
- In `fairseq <https://github.com/pytorch/fairseq>`_, FSDP is activated by the command line option ``--ddp-backend=fully_sharded``.
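
The following is a minimal sketch combining several of the options above. It is
not a drop-in recipe: it assumes ``torch.distributed`` is already initialized, a
CUDA device is available, and that ``checkpoint_wrapper`` is importable from
``fairscale.nn.misc`` (the exact import path may differ between fairscale versions).

.. code-block:: python

    import torch
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
    from fairscale.nn.misc import checkpoint_wrapper  # import path is an assumption

    layer = torch.nn.Linear(1024, 1024).cuda()
    # checkpoint_wrapper goes *inside* FSDP, as recommended above
    wrapped = FSDP(
        checkpoint_wrapper(layer),
        mixed_precision=True,
        reshard_after_forward=True,
    )

    x = torch.randn(8, 1024, device="cuda", requires_grad=True)
    with torch.cuda.amp.autocast():
        loss = wrapped(x).sum()
    loss.backward()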

How it works
------------
In standard distributed data parallel (DDP) training every worker processes a separate batch and the gradients are
summed across workers using an `all-reduce operation <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce>`__.
While DDP has become very popular, it wastes GPU memory because the model weights and optimizer states are replicated across all DDP workers.

The key insight to unlock full parameter sharding is that we can decompose the
`all-reduce <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce>`__
operation in DDP into separate
`all-gather <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allgather>`__
and
`reduce-scatter <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#reducescatter>`__
operations:

.. |Figure 1| image:: https://user-images.githubusercontent.com/23240128/110170085-a67b6280-7dc7-11eb-9128-88d813fc7037.png

|Figure 1|
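
As an illustration of this equivalence (not FSDP's internal implementation), the
same reduction can be written with raw ``torch.distributed`` collectives. The
sketch below assumes an initialized process group and a gradient whose length is
divisible by the world size:

.. code-block:: python

    import torch
    import torch.distributed as dist

    def sharded_all_reduce(grad: torch.Tensor) -> torch.Tensor:
        """Equivalent to dist.all_reduce(grad), but done in two steps."""
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        chunks = [c.contiguous() for c in grad.chunk(world_size)]

        # reduce-scatter: each rank receives the sum of one shard of the gradient
        my_shard = torch.empty_like(chunks[rank])
        dist.reduce_scatter(my_shard, chunks, op=dist.ReduceOp.SUM)

        # all-gather: reassemble the fully reduced gradient on every rank
        gathered = [torch.empty_like(c) for c in chunks]
        dist.all_gather(gathered, my_shard)
        return torch.cat(gathered)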

Then, we can rearrange the reduce-scatter + all-gather so that each DDP worker only needs to store a single shard of parameters and optimizer state. The figure below illustrates standard DDP training (left) and fully sharded training (right):

.. |Figure 2| image:: https://user-images.githubusercontent.com/231798/109069252-f9199800-76be-11eb-96f8-86767edf1eb9.png

|Figure 2|

To maximize memory efficiency we can discard the full weights after each
layer's forward pass, saving memory for subsequent layers. This can be
implemented by applying the FSDP wrapper to every layer in your network
(with ``reshard_after_forward=True``). In pseudo-code:

::

    FSDP forward pass:
        for layer_i in layers:
            all-gather full weights for layer_i
            forward pass for layer_i
            discard full weights for layer_i

    FSDP backward pass:
        for layer_i in reversed(layers):
            all-gather full weights for layer_i
            backward pass for layer_i
            discard full weights for layer_i
            reduce-scatter gradients for layer_i
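
The per-layer wrapping that this pseudo-code assumes can be expressed by nesting
``FSDP`` wrappers explicitly (``auto_wrap``, described below, automates this).
A rough sketch, assuming ``torch.distributed`` is initialized and the layer
sizes are placeholders:

.. code-block:: python

    import torch
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    layers = [torch.nn.Linear(1024, 1024).cuda() for _ in range(4)]
    model = FSDP(
        torch.nn.Sequential(
            # each inner wrapper gathers and discards its own weights per layer
            *[FSDP(layer, reshard_after_forward=True) for layer in layers]
        ),
        reshard_after_forward=True,
    )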

Saving and Loading
------------------

There are two ways to save and load FSDP instances:

- ``state_dict()`` returns a dictionary containing the full (unsharded) parameters, which can be loaded with ``load_state_dict()``
- ``local_state_dict()`` returns a dictionary containing only the local shard's parameters, which can be loaded with ``load_local_state_dict()``
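
A brief sketch of the two paths (``sharded_module`` is assumed to be an ``FSDP``
instance and ``torch.distributed`` to be initialized):

.. code-block:: python

    import torch

    # consolidated checkpoint: the full, unsharded parameters
    torch.save(sharded_module.state_dict(), "ckpt_full.pt")
    sharded_module.load_state_dict(torch.load("ckpt_full.pt"))

    # sharded checkpoint: each rank saves and restores only its own shard
    rank = torch.distributed.get_rank()
    torch.save(sharded_module.local_state_dict(), f"ckpt_shard{rank}.pt")
    sharded_module.load_local_state_dict(torch.load(f"ckpt_shard{rank}.pt"))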


Mixed Precision
---------------

When ``mixed_precision=True``:

- Sharded parameters are downcast to FP16 before the forward pass and promoted back to FP32 after the forward pass.
- Buffers are kept in FP16 unless ``buffer_dtype=torch.float32`` is passed. Buffers are never sharded, regardless of arguments.
- By default, gradients are computed and reduced in FP16. If FP32 reductions are important, set ``fp32_reduce_scatter=True``.
- If ``torch.cuda.amp.autocast`` is enabled, it will override the output dtypes of some operations, such as ``BatchNorm2d``.
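
A configuration sketch of the knobs described above (``my_module`` is a
placeholder and ``torch.distributed`` is assumed to be initialized):

.. code-block:: python

    import torch
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    sharded = FSDP(
        my_module,                    # any nn.Module, defined elsewhere
        mixed_precision=True,         # FP16 params and grads in forward/backward
        fp32_reduce_scatter=True,     # reduce gradients in FP32
        buffer_dtype=torch.float32,   # keep (unsharded) buffers in FP32
    )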


Auto-wrap
~~~~~~~~~
Auto wrapping sub-modules with ``FSDP`` is a convenient way to improve training speed by overlapping the all-gather step across the forward passes of different submodules.
It also improves memory efficiency by freeing gathered parameters after each layer finishes executing.



.. code-block:: python

    import torch
    from fairscale.nn.wrap import auto_wrap, enable_wrap, wrap
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
    from fairscale.utils.testing import DummyProcessGroup

    tfmr = torch.nn.Transformer(num_encoder_layers=2, num_decoder_layers=2)

    group = DummyProcessGroup(rank=0, size=1)
    fsdp_params = dict(mixed_precision=True, flatten_parameters=True)
    with enable_wrap(wrapper_cls=FSDP, process_group=group, **fsdp_params):
        # Wraps the layer in FSDP by default if within the context
        l1 = wrap(torch.nn.Linear(5, 5))
        assert isinstance(l1, FSDP)
        assert l1.mixed_precision and l1.flatten_parameters
        # Separately wraps child modules with more than 1e6 params
        tfmr_auto_wrapped = auto_wrap(tfmr, min_num_params=1e6)

    assert isinstance(tfmr_auto_wrapped, torch.nn.Transformer)
    for layer in tfmr_auto_wrapped.encoder.layers:
        assert isinstance(layer, FSDP)
        assert layer.mixed_precision and layer.flatten_parameters
        assert isinstance(layer.linear1, FSDP)
        assert isinstance(layer.linear2, FSDP)
        assert not isinstance(layer.self_attn, FSDP)  # self-attention is not auto-wrapped


.. warning:: It is not recommended to use :func:`auto_wrap` with
:class:`FullyShardedDataParallel` on modules that have shared
parameters, as the parameter sharing may be broken (i.e. end up not
shared) if the shared parameters are not (auto-)wrapped under the same
FSDP wrapper instance.


Using CPU RAM
-------------

``move_grads_to_cpu`` and ``cpu_offload`` control which tensors get
moved to CPU.

- ``cpu_offload`` moves weights to CPU when they are not being used.
- ``move_grads_to_cpu`` moves gradients to CPU so that the optimizer step also happens on CPU. This option requires ``cpu_offload=True``.
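
A sketch of the CPU-offload configuration (``my_module`` is a placeholder and
``torch.distributed`` is assumed to be initialized):

.. code-block:: python

    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    sharded = FSDP(
        my_module,
        mixed_precision=True,
        cpu_offload=True,         # park weights in CPU RAM while unused
        move_grads_to_cpu=True,   # run the optimizer step on CPU gradients
    )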

Gradient Clipping
-----------------

By default,

.. code-block:: python

sharded_module = FullyShardedDataParallel(my_module)
torch.nn.utils.clip_grad_norm_(sharded_module.parameters(), max_norm=1.0)

will clip with an incorrect norm, because each rank only computes the norm
over the parameters in its own shard. To overcome this, either call
``sharded_module.clip_grad_norm(1.0)``, which does the extra work required
to compute the global norm properly, or use
``torch.nn.utils.clip_grad_value_``, which clips each gradient element
independently and needs no global information.
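
In a training step, the sharding-aware call looks roughly like this (the method
name follows the note above; check your fairscale version for the exact API):

.. code-block:: python

    loss = sharded_module(inputs).sum()   # `inputs` and `optimizer` assumed to exist
    loss.backward()
    sharded_module.clip_grad_norm(1.0)    # computes the norm across all shards
    optimizer.step()
    optimizer.zero_grad()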


State Management with extra parameter attributes
------------------------------------------------

We manage several attributes on each Parameter instance. The first two
are set by :func:`_shard_parameters_`:

- ``_is_sharded``: ``True`` if the Parameter is sharded or ``False``
if the Parameter is intentionally not sharded (in which case we
will all-reduce grads for this param).
- ``_orig_size``: the size of the original Parameter (before sharding)


The remaining attributes are set in ``_init_param_attributes()``:

- ``_fp32_shard``: a single shard of the parameters in full precision
  (typically FP32, but this depends on the dtype of the model as passed
  in by the user). This can be on CPU or GPU depending on the value of ``cpu_offload``.
- ``_fp16_shard``: if ``mixed_precision`` is ``True``, this will be
a single shard of the parameters in FP16, used for all-gather.
- ``_full_param_padded``: the full weight (padded to be evenly divisible by ``world_size``), used for computation in the
forward and backward pass. This will be resized in place and only materialized (via all-gather) as needed.
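
These are internal attributes and may change between versions, but they can be
inspected for debugging. A purely illustrative peek, assuming ``sharded_module``
is an ``FSDP`` instance:

.. code-block:: python

    for p in sharded_module.parameters():
        print(p._is_sharded, p._orig_size)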

Misc
----
- We don't start the FP32 -> FP16 transfer until after the optimization step completes.
- Any direct weight access outside of the forward/backward pass should happen inside the ``_summon_full_params`` context.

3 changes: 2 additions & 1 deletion docs/source/index.rst
@@ -24,6 +24,7 @@ Components
* `Sharded grad scaler - automatic mixed precision <../../en/latest/api/optim/grad_scaler.html>`_
* `Sharded distributed data parallel <../../en/latest/api/nn/sharded_ddp.html>`_
* `Fully Sharded Data Parallel FSDP <../../en/latest/api/nn/fsdp.html>`_
* `FSDP Tips <../../en/latest/api/nn/fsdp_tips.html>`_

* Optimization at scale:
* `AdaScale SGD <../../en/latest/api/optim/adascale.html>`_
@@ -39,7 +40,7 @@ Components
This library is under active development.
Please be mindful and create an
`issue <https://github.com/facebookresearch/fairscale/issues>`_
if you have any trouble and/or suggestion.
if you have any trouble and/or suggestions.

.. toctree::
:maxdepth: 5
17 changes: 9 additions & 8 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -79,14 +79,15 @@ class FullyShardedDataParallel(nn.Module):
models and to improve training speed by overlapping the all-gather step
across the forward pass. For example::

sharded_model = FullyShardedDataParallel(
nn.Sequential( # doesn't have to be nn.Sequential
nn.Linear(5, 100),
FullyShardedDataParallel(nn.Linear(100, 100)),
FullyShardedDataParallel(nn.Linear(100, 100)),
nn.Linear(100, 5),
)
)
from fairscale.nn.auto_wrap import enable_wrap, auto_wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
fsdp_params = dict(mixed_precision=True, flatten_parameters=True)
with enable_wrap(**fsdp_params):
Reviewer comment: the new syntax is ``enable_wrap(wrapper_cls=FSDP, **fsdp_params)``

# Wraps layer in FSDP by default if within context
self.l1 = wrap(torch.nn.Linear(5, 5))
assert isinstance(self.l1, FSDP)
# Separately Wraps children modules with more than 1e8 params
self.l2 = auto_wrap(TransformerBlock(), min_num_params=1e8)

.. warning::
