[fix]: let FSDP handle model with multiple forward pass and checkpoint #621

Merged 20 commits on Apr 26, 2021
91 changes: 72 additions & 19 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -115,9 +115,15 @@ class FullyShardedDataParallel(nn.Module):
an assert on the backward pass. The solution is to leave some parameters
to the outer FSDP.

.. warning::

If activation checkpointing is used with FSDP, it is strongly encouraged
to use the ``checkpoint_wrapper`` function from FairScale instead of the
``checkpoint`` function from PyTorch.

Args:
module (nn.Module):
module to checkpoint
module to be wrapped with FullyShardedDataParallel.
process_group (Optional):
process group for sharding
reshard_after_forward (bool, Optional):
@@ -207,7 +213,7 @@ def __init__(
self.no_broadcast_optim_state = no_broadcast_optim_state
self.state_dict_device = state_dict_device or self.compute_device

self.gradient_predivide_factor: int = self.get_gradient_predivide_factor(self.world_size)
self.gradient_predivide_factor: float = self._get_gradient_predivide_factor(self.world_size)
self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor

self.numel_padded_per_param: List[int] = []
@@ -275,11 +281,31 @@ def __init__(
f"FSDP.__init__(done): total_init_time: {(init_end - init_start): .4f} num_params: {(sum(p.numel() for p in self.params))}"
)

def get_gradient_predivide_factor(self, world_size: int) -> int:
factor = 1
# Flag to guard against the pre-backward hook being executed multiple
# times per iteration. This is reset at the end of the backward pass.
self._pre_backward_hook_has_run = False

def _get_gradient_predivide_factor(self, world_size: int) -> float:
factor: int = 1
while world_size % factor == 0 and world_size / factor > factor:
factor = factor * 2
return factor
factor *= 2
return float(factor)
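
As an editor's illustration (not part of this diff), the loop above doubles the factor while it still evenly divides world_size and is smaller than the remaining post-divide factor; pre-dividing by ``factor`` and post-dividing by ``world_size / factor`` still divides by ``world_size`` overall, while keeping intermediate gradient values in a safer range for reduced-precision reductions:

```python
def predivide_factor(world_size: int) -> float:
    # Standalone copy of the loop in _get_gradient_predivide_factor above.
    factor = 1
    while world_size % factor == 0 and world_size / factor > factor:
        factor *= 2
    return float(factor)

assert predivide_factor(8) == 4.0    # pre-divide by 4, post-divide by 8 / 4 = 2
assert predivide_factor(16) == 4.0   # pre-divide by 4, post-divide by 16 / 4 = 4
assert predivide_factor(64) == 8.0   # pre-divide by 8, post-divide by 64 / 8 = 8
```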

def set_gradient_divide_factors(self, pre: float, post: float, recursive: bool) -> None:
"""Allowing user to override the pre and post divide factors.

Args:
pre (float): divide factor before the reduction.
post (float): divide factor after the reduction.
recursive (bool): recursively set it for all child FSDP instances or not.
"""
self.assert_state(TrainingState.IDLE)
if recursive:
for module in self.modules():
if isinstance(module, FullyShardedDataParallel) and module != self:
module.set_gradient_divide_factors(pre, post, False)
self.gradient_predivide_factor = pre
self.gradient_postdivide_factor = post
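
A minimal usage sketch for the new API (editor's illustration, not from the PR; assumes torch.distributed and the default process group are already initialized):

```python
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

model = FSDP(nn.Linear(16, 16).cuda())

# Average gradients entirely before the reduction and skip the post-reduction
# division, for this instance and (recursively) all nested FSDP children.
model.set_gradient_divide_factors(float(model.world_size), 1.0, recursive=True)
```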

@property
def module(self) -> nn.Module:
@@ -943,7 +969,13 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
self._use_fp32_param_shard()

# Register pre-backward hooks to all-gather the params for the backward
# pass (if needed).
# pass (if the output's grad is needed). This won't register anything if
# we are in eval mode.
#
# Some models run the forward pass multiple times; we need to register the
# pre-backward hook on every output since the last output's hook has to
# fire first to set up for the backward pass. We use ``self._pre_backward_hook_has_run``
# to prevent the repeated overhead of multiple hook callbacks.
outputs = self._register_pre_backward_hooks(outputs)

# Done with a forward pass.
@@ -953,16 +985,18 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:

def _register_pre_backward_hooks(self, outputs: Any) -> Any:
"""Register pre-backward hook to run before the wrapped module's
backward. Hooks should be attached to all outputs from the forward."""
backward. Hooks should be attached to all outputs from the forward.

Returns:
outputs: new outputs with hooks registered if they require gradient.
"""
if not torch.is_grad_enabled():
return outputs # don't register hooks if grad isn't enabled

pre_backward_hook_has_run = [False]
Contributor Author commented:

Originally, this variable was local to the _register_pre_backward_hooks function, so it only prevented duplicate callbacks registered within a single call. When this function is called multiple times, a local flag cannot prevent duplicate callbacks across calls. Moving the flag to self solves this problem.

def _pre_backward_hook(*unused: Any) -> None:
if pre_backward_hook_has_run[0]:
return # only run once
pre_backward_hook_has_run[0] = True
if self._pre_backward_hook_has_run:
return # only run once (from multiple outputs or multiple forward passes)
self._pre_backward_hook_has_run = True

# Start of a backward pass.
self.assert_state([TrainingState.IDLE, TrainingState.BACKWARD_PRE])
@@ -1062,13 +1096,27 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
the local optimizer only sees the relevant parameter shard.
"""
# First hook callback will see PRE state. If we have multiple params,
# then subsequent hook callbacks will see POST state.
self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
# then subsequent hook callbacks will see POST state. When the checkpoint
# forward counter is used, IDLE is also possible since the pre-backward hook
# is not triggered (see ``auto_wrap_bn`` below; we have to use
# FSDP(checkpoint(conv, FSDP(bn), ...)) with reshard_after_forward=False).
if hasattr(self, "_checkpoint_fwd_counter"):
self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST, TrainingState.IDLE])
else:
self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
self.training_state = TrainingState.BACKWARD_POST
if param.grad is None:
return

if param.grad.requires_grad:
raise RuntimeError("FullyShardedDataParallel only works with gradients that don't require grad")
raise RuntimeError("FullyShardedDataParallel only works with gradients that don't require gradients")
Contributor commented:

Is this a thing? Which gradients need gradients!?

Contributor Author replied:

Great question! I think it came from @myleott originally. I would love to know more too. :-)


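For context on the question above (editor's note, not from the thread): a gradient that itself requires grad typically comes from double backward, i.e. ``backward(create_graph=True)``:

```python
import torch

w = torch.randn(3, requires_grad=True)
loss = (w * w).sum()
loss.backward(create_graph=True)  # keep the graph so higher-order grads are possible
print(w.grad.requires_grad)       # True -- the case this RuntimeError rejects
```
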
# If this is a checkpointed module, we check whether the following
# counter has reached 0. If not, this is not the final backward call
# for this module yet, so we return early.
if hasattr(self._fsdp_wrapped_module, "_checkpoint_fwd_counter"):
if self._fsdp_wrapped_module._checkpoint_fwd_counter != 0:
return

if self._require_backward_grad_sync or self.reshard_after_forward:
# Free full params. As a special case, we don't free the full params
@@ -1200,6 +1248,7 @@ def _remove_shard_bwd_hook(fsdp_module: FullyShardedDataParallel) -> None:
for m in self.modules(): # includes self
if isinstance(m, FullyShardedDataParallel):
_remove_shard_bwd_hook(m)
m._pre_backward_hook_has_run = False
if m._has_params:
if any(p.requires_grad for p in m.params):
m.assert_state(TrainingState.BACKWARD_POST)
@@ -1395,8 +1444,8 @@ def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None
# In case we are failing in the context of autograd hook, asserting
# may not generate useful msg. So, let's print it to be sure.
if self.rank == 0:
print(self)
print(msg)
print(f"Asserting FSDP instance is: {self}")
print(f"ERROR: {msg}")
traceback.print_stack()
raise ValueError(msg)

@@ -1543,7 +1592,7 @@ def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any])
v_shard = v[0] if self.rank >= len(v) else v[self.rank]
assert ou.is_singleton_tensor(v_shard)
else:
v_shard = v # dont shard entries that are not tensors
v_shard = v # don't shard entries that are not tensors
full_optim_state_dict["state"][id][k] = v_shard

return full_optim_state_dict
@@ -1686,6 +1735,10 @@ def wrap_bn_only_policy(module: nn.Module, recurse: bool, unwrapped_params: int)
"process_group": pg,
"mixed_precision": False, # Keep the weights in FP32.
"flatten_parameters": False, # Do not flatten.
# Reshard==False is good for performance. When FSDP(checkpoint(FSDP(bn))) is used, this
# **must** be False because BN's FSDP wrapper's pre-backward callback isn't called
# within the checkpoint's outer backward when multiple forward passes are used.
"reshard_after_forward": False,
}

with enable_wrap(wrap_bn_only_policy, **fsdp_config):
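Taken together, the changes to fully_sharded_data_parallel.py above support the pattern below (editor's sketch with assumed names and import paths, not code from the PR): a checkpointed block whose forward runs more than once per iteration, wrapped by FSDP, with the forward counter maintained so the post-backward hook only frees and reduces after the final backward:

```python
import torch
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.misc import checkpoint_wrapper


class TwoPassBlock(nn.Module):
    """Hypothetical model that calls the same checkpointed block twice per forward."""

    def __init__(self) -> None:
        super().__init__()
        self.block = checkpoint_wrapper(nn.Linear(8, 8), maintain_forward_counter=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(self.block(x))


# Assumes torch.distributed is initialized. The outer FSDP registers a
# pre-backward hook on every output; _pre_backward_hook_has_run and the
# checkpoint forward counter make sure setup and teardown each run once.
model = FSDP(TwoPassBlock().cuda())
loss = model(torch.randn(4, 8).cuda()).sum()
loss.backward()
```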
21 changes: 17 additions & 4 deletions fairscale/nn/misc/checkpoint_activations.py
@@ -15,10 +15,12 @@

from fairscale.utils.containers import pack_kwargs, split_non_tensors, unpack_kwargs, unpack_non_tensors

from .misc import patch_batchnorm
from .misc import dec_counter, inc_counter, init_counter, patch_batchnorm


def checkpoint_wrapper(module: nn.Module, offload_to_cpu: bool = False) -> nn.Module:
def checkpoint_wrapper(
module: nn.Module, offload_to_cpu: bool = False, maintain_forward_counter: bool = False
) -> nn.Module:
"""
A friendlier wrapper for performing activation checkpointing.

@@ -58,16 +60,23 @@ def checkpoint_wrapper(module: nn.Module, offload_to_cpu: bool = False) -> nn.Mo
Args:
module (nn.Module):
The module to be wrapped
offload_to_cpu (Optional, bool):
offload_to_cpu (bool):
Whether to offload activations to CPU.
maintain_forward_counter (bool):
If True, maintain a forward counter per inner module. The counter first
increases during the forward calls of the outer forward pass and then decreases
during the forward calls of the outer backward pass (re-computation). It is used by FullyShardedDataParallel.

Returns:
(nn.Module):
Wrapped module
"""
# Patch the batchnorm layers in case there are any.
# Patch the batchnorm layers in case there are any in this module.
patch_batchnorm(module)

if maintain_forward_counter:
init_counter(module)

# The use of weakref here is to prevent creating a ref cycle: m -> m.forward -> m.
# When such cycle exists, gc won't collect the module when the module is freed.
# That causes GPU memory to be leaked. See the unit test for how we catch that.
@@ -168,6 +177,8 @@ def forward( # type: ignore
with torch.no_grad():
unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
outputs = run_function(*unpacked_args, **unpacked_kwargs)
the_module = unpacked_args[0]
inc_counter(the_module)

if not isinstance(outputs, torch.Tensor):
# Autograd Functions don't like non-Tensor outputs. We can split the
@@ -200,6 +211,8 @@ def backward(ctx: Any, *args: Any) -> Tuple[Optional[Tensor], ...]:
unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs)
outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs)
tensor_outputs, _ = split_non_tensors(outputs)
the_module = unpacked_args[0]
dec_counter(the_module)

# Set the states back to what it was at the start of this function.
set_rng_state(bwd_rng_state)
27 changes: 26 additions & 1 deletion fairscale/nn/misc/misc.py
@@ -27,7 +27,6 @@ def patch_batchnorm(module: nn.Module) -> List:
(list):
A list of hook handles that can later be freed.
"""
hooks = []

def pre_forward(module: _BatchNorm, input: Tensor) -> None:
if torch.is_grad_enabled():
@@ -40,6 +39,7 @@ def post_forward(module: _BatchNorm, input: Tensor, result: Tensor) -> None:
return
module.track_running_stats = module._track_running_stats_backup

hooks = []
for name, child in module.named_modules():
# _BatchNorm is base for bn1d, bn2d, bn3d and sync_bn, apex_sync_bn, etc.
if isinstance(child, _BatchNorm):
@@ -48,3 +48,28 @@ def post_forward(module: _BatchNorm, input: Tensor, result: Tensor) -> None:
post_handle = child.register_forward_hook(post_forward)
hooks += [pre_handle, post_handle]
return hooks


def init_counter(module: nn.Module) -> None:
"""Add a checkpoint forward pass counter to a module and all its child FSDP modules.

``inc_counter`` and ``dec_counter`` are used together with this to maintain counters
for FSDP to use when multiple forward passes and activation checkpointing are used at the same time.
"""
for mod in module.modules():
mod._checkpoint_fwd_counter = 0


def _add_counter(module: nn.Module, value: int) -> None:
if not hasattr(module, "_checkpoint_fwd_counter"):
return
for mod in module.modules():
mod._checkpoint_fwd_counter += value


def inc_counter(module: nn.Module) -> None:
_add_counter(module, 1)


def dec_counter(module: nn.Module) -> None:
_add_counter(module, -1)
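
A small sketch of the counter lifecycle these helpers implement (editor's illustration, assuming the helpers are importable from fairscale.nn.misc.misc as in this diff): inc_counter is called in the checkpoint's outer forward and dec_counter during the recomputation in the backward pass, so the counter only returns to zero after the final backward for the module:

```python
import torch.nn as nn
from fairscale.nn.misc.misc import dec_counter, inc_counter, init_counter

m = nn.Linear(4, 4)
init_counter(m)                        # attaches _checkpoint_fwd_counter = 0
inc_counter(m)                         # 1st forward pass of the iteration
inc_counter(m)                         # 2nd forward pass of the iteration
assert m._checkpoint_fwd_counter == 2

dec_counter(m)                         # recomputation during the 1st backward call
assert m._checkpoint_fwd_counter == 1  # not the final backward yet -> FSDP early-returns

dec_counter(m)                         # recomputation during the final backward call
assert m._checkpoint_fwd_counter == 0  # FSDP may now free full params and reduce grads
```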
15 changes: 14 additions & 1 deletion fairscale/utils/testing.py
@@ -26,6 +26,7 @@
relative imports.
"""

import contextlib
import functools
import inspect
import logging
@@ -35,7 +36,7 @@
import subprocess
import sys
import tempfile
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Tuple, Union

import numpy
import pytest
@@ -645,3 +646,15 @@ def rmf(filename: str) -> None:
os.remove(filename)
except FileNotFoundError:
pass


@contextlib.contextmanager
def temp_files_ctx(num: int) -> Generator:
""" A context to get tempfiles and ensure they are cleaned up. """
files = [tempfile.mkstemp()[1] for _ in range(num)]

yield tuple(files)

# temp files could have been removed, so we use rmf.
for name in files:
rmf(name)
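
A usage sketch for the new helper (editor's illustration):

```python
from fairscale.utils.testing import temp_files_ctx

# Request two temp file paths; rmf() removes them on exit even if a test
# already deleted one of them.
with temp_files_ctx(2) as (sync_file, output_file):
    with open(sync_file, "w") as f:
        f.write("e.g. a file:// init_method for torch.distributed")
```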
6 changes: 4 additions & 2 deletions stubs/torch/nn/modules/module.pyi
@@ -108,6 +108,8 @@ class Module(Generic[T_co]):

def extra_repr(self) -> str: ...

#MODIFIED BY TORCHGPIPE
# This is added by checkpoint_wrapper
_checkpoint_fwd_counter: int

# This is added torchgpipe
training: bool
#END
7 changes: 4 additions & 3 deletions tests/ci_test_list_1.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
tests/nn/misc/test_flatten_params_wrapper.py
tests/nn/data_parallel/test_fsdp.py
tests/nn/data_parallel/test_fsdp_freezing_weights.py
tests/nn/data_parallel/test_fsdp_multiple_forward_checkpoint.py
tests/nn/data_parallel/test_fsdp_multiple_wrapping.py
tests/nn/data_parallel/test_fsdp_freezing_weights.py
tests/nn/data_parallel/test_fsdp.py
tests/nn/misc/test_flatten_params_wrapper.py
2 changes: 1 addition & 1 deletion tests/nn/data_parallel/test_fsdp_freezing_weights.py
@@ -118,7 +118,7 @@ def temp_files():


@skip_if_single_gpu
def tests1(temp_files):
def test_freezing_weights(temp_files):
world_size = 2
# DDP
fsdp = False