[FSDPv1] Only perform cat() during last microbatch backward() within FlattenParamsWrapper #1184

Merged
57 changes: 51 additions & 6 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -340,6 +340,12 @@ class FullyShardedDataParallel(nn.Module):
rank 0 and return empty dict non-rank 0, which allow FullyShardedDataParallel to
skip the GPU -> CPU copy on non-rank 0 altogether and prevent OOM.
Default: False
optimize_backward_concat (bool):
If True, only let the backward pass propagate to self.params, which will
invoke the _post_backward_hook() and concat() op, when self._require_backward_grad_sync
is True (e.g. on the last microbatch).
Default: False
NOTE: this will likely incur higher GPU memory usage
"""

def __init__(
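A minimal usage sketch of the new flag (illustration only, not part of this diff; it assumes a CUDA setup, and mixed_precision=True is assumed since fp32_reduce_scatter is normally used in a mixed-precision configuration):

import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# optimize_backward_concat requires fp32_reduce_scatter=True (see the assert added further down).
model = FSDP(
    nn.Linear(1024, 1024).cuda(),
    flatten_parameters=True,
    mixed_precision=True,
    fp32_reduce_scatter=True,
    optimize_backward_concat=True,
)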
@@ -371,6 +377,7 @@ def __init__(
limit_all_gather_events: bool = False,
limit_reduce_scatter_events: bool = False,
cast_input: bool = True,
optimize_backward_concat: bool = False,
):
try:
import torch._C
@@ -496,8 +503,12 @@ def __init__(
param_name_groups = [param_names]
del param_names

self.optimize_backward_concat = optimize_backward_concat
if self.optimize_backward_concat:
assert self.fp32_reduce_scatter, f"{optimize_backward_concat=} requires self.fp32_reduce_scatter=True"

self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(
module, param_list=to_be_flatten_params, ssd_offload=self.ssd_offload, ssd_directory=self.ssd_directory
module, param_list=to_be_flatten_params, ssd_offload=self.ssd_offload, ssd_directory=self.ssd_directory, optimize_backward_concat=self.optimize_backward_concat,
)
del module # free original module in case it helps garbage collection

@@ -854,6 +865,7 @@ def extra_repr(self) -> str:
f"bucket_cap_mb={self.bucket_cap_mb}, "
f"clear_autocast_cache={self.clear_autocast_cache}"
f"force_input_to_fp32={self.force_input_to_fp32}"
f"optimize_backward_concat={self.optimize_backward_concat}"
)
return repr

@@ -1102,12 +1114,20 @@ def no_sync(self) -> Generator:
if isinstance(m, FullyShardedDataParallel):
old_flags.append((m, m._require_backward_grad_sync))
m._require_backward_grad_sync = False
if self.optimize_backward_concat:
# Set the flag on the wrapped FlattenParamsWrapper module as well,
# so that FlattenParamsWrapper can accumulate grads at the corresponding
# leaf nodes without triggering concat operations when gradient
# synchronization is not needed.
m._fsdp_wrapped_module._require_backward_grad_sync = False
try:
yield
finally:
for m, old_flag in old_flags:
assert m._require_backward_grad_sync is False
m._require_backward_grad_sync = old_flag
if self.optimize_backward_concat:
m._fsdp_wrapped_module._require_backward_grad_sync = old_flag

@contextlib.contextmanager
def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Generator:
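To make the interaction with no_sync() concrete, here is a hedged sketch of the gradient-accumulation loop this flag targets; model, optimizer, and microbatches are assumed to be defined by the caller (e.g. model built as in the construction sketch above):

for i, batch in enumerate(microbatches):
    is_last = (i == len(microbatches) - 1)
    if is_last:
        # _require_backward_grad_sync stays True: backward reaches the flat
        # params, so the fp32 grads are concatenated and reduce-scattered.
        model(batch).sum().backward()
    else:
        # _require_backward_grad_sync is False on both FSDP and the wrapped
        # FlattenParamsWrapper: backward only accumulates per-view fp32 grads.
        with model.no_sync():
            model(batch).sum().backward()
optimizer.step()
optimizer.zero_grad()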
@@ -1744,10 +1764,17 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
self._use_fp32_param_shard([param])

if self.fp32_reduce_scatter:
if getattr(param, "unsharded_main_grad", None) is None:
param.unsharded_main_grad = param.grad.to(torch.float32)
if self.optimize_backward_concat:
# Flatten and concat the accumulated fp32 grads
# and assign them to param.unsharded_main_grad
param.unsharded_main_grad = torch.cat([grad.flatten() for grad in self._fsdp_wrapped_module.fp32_grads])
# Clean up accumulated grads between data batches
self._fsdp_wrapped_module.fp32_grads = []
else:
param.unsharded_main_grad.add_(param.grad.data)
if getattr(param, "unsharded_main_grad", None) is None:
param.unsharded_main_grad = param.grad.to(torch.float32)
else:
param.unsharded_main_grad.add_(param.grad.data)

param.grad = None
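For intuition, a standalone sketch of the concat performed above, using plain tensors (the names mirror the fields in this hunk, but nothing here is fairscale code):

import torch

fp32_grads = [torch.ones(2, 3), torch.ones(4)]  # accumulated per-parameter fp32 grads
unsharded_main_grad = torch.cat([g.flatten() for g in fp32_grads])
assert unsharded_main_grad.shape == (10,)  # a single flat fp32 grad matching the FlatParameter layout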

@@ -1896,7 +1923,16 @@ def _wait_for_post_backward(self) -> None:
# all the params, the post_backward hook will not fire and the
# state will remain in `TrainingState.BACKWARD_PRE`.
if any([p.requires_grad for p in self.params]):
self.assert_state(TrainingState.BACKWARD_POST)
if self.optimize_backward_concat:
# If self.optimize_backward_concat == True, the FSDP post-backward hook
# (which invokes concat()) should only have been triggered when
# self._fsdp_wrapped_module._require_backward_grad_sync is True
if self._fsdp_wrapped_module._require_backward_grad_sync:
self.assert_state(TrainingState.BACKWARD_POST)
else:
self.assert_state(TrainingState.BACKWARD_PRE)
else:
self.assert_state(TrainingState.BACKWARD_POST)
else:
self.assert_state(TrainingState.BACKWARD_PRE)

@@ -1981,7 +2017,16 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
# all the params, the post_backward hook will not fire and the
# state will remain in `TrainingState.BACKWARD_PRE`.
if any([p.requires_grad for p in m.params]):
m.assert_state(TrainingState.BACKWARD_POST)
if self.optimize_backward_concat:
# If self.optimize_backward_concat == True, the FSDP post-backward hook
# (which invokes concat()) should only have been triggered when
# self._fsdp_wrapped_module._require_backward_grad_sync is True
if self._fsdp_wrapped_module._require_backward_grad_sync:
m.assert_state(TrainingState.BACKWARD_POST)
else:
m.assert_state(TrainingState.BACKWARD_PRE)
else:
m.assert_state(TrainingState.BACKWARD_POST)
else:
m.assert_state(TrainingState.BACKWARD_PRE)
else:
58 changes: 57 additions & 1 deletion fairscale/nn/misc/flatten_params_wrapper.py
@@ -33,6 +33,7 @@

from fairscale.experimental.nn.ssd_offload import SsdFlatParameter
from fairscale.utils.state_dict import replace_by_prefix_
import functools

if TYPE_CHECKING:
from collections import OrderedDict # noqa: F401
@@ -148,6 +149,11 @@ class FlattenParamsWrapper(nn.Module):
flat_param_names (Optional[List[str]]):
originally, give each flat_param a unique name. Note a "flat_param_"
prefix will be added to those names.
optimize_backward_concat (bool):
If True, only let the backward pass propagate to the corresponding FSDP.params, which will
invoke the FSDP._post_backward_hook() and concat() op, when _require_backward_grad_sync
is True (e.g. on the last microbatch).
NOTE: this will likely incur higher GPU memory usage
"""

def __init__(
@@ -157,10 +163,18 @@ def __init__(
flat_param_names: Optional[List[str]] = None,
ssd_offload: bool = False,
ssd_directory: str = "",
optimize_backward_concat: bool = False,
):
super().__init__()
self._fpw_module = module
self.is_flattened = False
self.optimize_backward_concat = optimize_backward_concat
# If optimize_backward_concat == True, used to propagate the
# corresponding FSDP module's _require_backward_grad_sync flag
self._require_backward_grad_sync = True
# If optimize_backward_concat == True, used to accumulate the
# fp32 gradients for the flattened parameters
self.fp32_grads = []

# Handle param_list being None.
if param_list is None:
@@ -364,18 +378,60 @@ def _unflatten_params(self, external_data: Optional[List[Optional[Tensor]]] = No
delattr(self, n)
self.flat_params = []

# The post backward hook used to accumulate fp32 gradients
def _grad_accumulation_hook(
self,
grad,
param_index,
):
if self.fp32_grads[param_index] is None:
self.fp32_grads[param_index] = grad.to(torch.float32)
else:
self.fp32_grads[param_index].add_(grad)
return grad

Review discussion on this hunk:

Comment: It's been a while since I've looked at the FSDP code, so this may be a silly question, but I wonder whether, for TransformerEngine modules that keep their weight grads in .main_grad in fp32 precision, this could be a duplication.

Reply: The fp32 weight grads being in .main_grad would be unexpected to me. My understanding was that we are not using that functionality and are instead relying on FSDP to do the fp32 gradient accumulation, since getting the weights' .main_grad to backprop correctly into the FSDP FlatParameter is tricky. I think that in the current approach, .main_grad should only be written to after the reduce-scatter:

param.main_grad = reduced_grad.data

However, since the memory usage is higher, maybe there is indeed some duplication with this approach.

Reply: Thanks! I just checked; we indeed don't use main_grad: https://github.com/fairinternal/xlformers/blob/main/src/model/te_layers.py#L104. I see that https://github.com/fairinternal/xlformers/pull/1418 and #1142 were not merged, but they are something we may want to consider to avoid gradient accumulation overhead.

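A standalone illustration of the accumulation hook above, using a plain tensor and torch.Tensor.register_hook (not fairscale code; the bf16 dtype is only an example): the first backward stores an fp32 copy of the gradient, later backwards add into it, and the original grad is passed through unchanged.

import torch

fp32_grads = [None]

def accumulate(grad, param_index=0):
    # Keep a persistent fp32 copy of the gradient across backward() calls.
    if fp32_grads[param_index] is None:
        fp32_grads[param_index] = grad.to(torch.float32)
    else:
        fp32_grads[param_index].add_(grad)
    return grad

w = torch.randn(4, dtype=torch.bfloat16, requires_grad=True)
w.register_hook(accumulate)
for _ in range(3):          # three "microbatches"
    (w * 2).sum().backward()
print(fp32_grads[0].dtype)  # torch.float32, accumulated over all three backwards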
def _unflatten_params_as_views(self) -> None:
"""Unlike ``_unflatten_params``, this function unflatten into views and keep
self.flat_param unchanged.
"""
assert self.is_flattened
ps = self.get_param_views()
if self.optimize_backward_concat:
# If self._require_backward_grad_sync == True (e.g. on the last microbatch),
# we use the original flat_params as autograd leaf nodes so that the backward
# pass propagates all the way back to the FSDP module and thus invokes the
# FSDP post_backward() hook and the concat() op.
# Otherwise, we stop the backward propagation before the FSDP module to avoid
# invoking concat(), and only store the accumulated fp32 grads.
if self._require_backward_grad_sync:
ps = self.get_param_views()
else:
with torch.no_grad():
ps = self.get_param_views()
else:
ps = self.get_param_views()

param_views = []
for (_, m, n), p in zip(self._param_infos, ps):
setattr(p, '_fsdp_weight', True)
setattr(m, n, p) # This will set as plain attr
if self.optimize_backward_concat:
# The param_index of parameter p, used to accumulate the corresponding
# gradients in self.fp32_grads
param_index = len(param_views)
# Register post backward hook to accumulate the gradients
# in self.fp32_grads
p.register_hook(
functools.partial(
self._grad_accumulation_hook,
param_index=param_index
)
)
param_views.append(p)

if self.optimize_backward_concat and len(self.fp32_grads) == 0:
# Allocate self.fp32_grads at the beginning of each data batch's forward()
self.fp32_grads = [None] * len(param_views)

# Save param views for easy access if anyone still wants to access
# parameters of the module.
setattr(self._fpw_module, "_unflattened_param_views", param_views)
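The comments in _unflatten_params_as_views describe stopping backprop before the FSDP module by creating the parameter views under torch.no_grad(); a minimal standalone illustration of that detachment idea with plain tensors (not fairscale code):

import torch

flat_param = torch.zeros(6, requires_grad=True)  # stand-in for the flattened parameter

tracked_view = flat_param.view(2, 3)             # normal view: autograd links it to flat_param
with torch.no_grad():
    detached_view = flat_param.view(2, 3)        # view created without autograd tracking

print(tracked_view.requires_grad)    # True  -> backward would propagate to flat_param
print(detached_view.requires_grad)   # False -> backward stops before flat_param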