diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index bfa02ae77..dfe17dcdc 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -340,6 +340,12 @@ class FullyShardedDataParallel(nn.Module):
             rank 0 and return empty dict non-rank 0, which allow FullyShardedDataParallel to
             skip the GPU -> CPU copy on non-rank 0 altogether and prevent OOM.
             Default: False
+        optimize_backward_concat (bool):
+            If True, only let the backward pass propagate to self.params, which will
+            invoke the _post_backward_hook() and concat() op, when self._require_backward_grad_sync
+            is True (e.g. on the last microbatch).
+            Default: False
+            NOTE: this likely will incur more GPU memory usage
     """

     def __init__(
@@ -371,6 +377,7 @@ def __init__(
         limit_all_gather_events: bool = False,
         limit_reduce_scatter_events: bool = False,
         cast_input: bool = True,
+        optimize_backward_concat: bool = False,
     ):
         try:
             import torch._C
@@ -496,8 +503,12 @@ def __init__(
             param_name_groups = [param_names]
         del param_names

+        self.optimize_backward_concat = optimize_backward_concat
+        if self.optimize_backward_concat:
+            assert self.fp32_reduce_scatter, f"{optimize_backward_concat=} requires self.fp32_reduce_scatter=True"
+
         self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(
-            module, param_list=to_be_flatten_params, ssd_offload=self.ssd_offload, ssd_directory=self.ssd_directory
+            module, param_list=to_be_flatten_params, ssd_offload=self.ssd_offload, ssd_directory=self.ssd_directory, optimize_backward_concat=self.optimize_backward_concat,
         )
         del module  # free original module in case it helps garbage collection

@@ -854,6 +865,7 @@ def extra_repr(self) -> str:
                 f"bucket_cap_mb={self.bucket_cap_mb}, "
                 f"clear_autocast_cache={self.clear_autocast_cache}"
                 f"force_input_to_fp32={self.force_input_to_fp32}"
+                f"optimize_backward_concat={self.optimize_backward_concat}"
             )
         return repr

@@ -1102,12 +1114,20 @@ def no_sync(self) -> Generator:
             if isinstance(m, FullyShardedDataParallel):
                 old_flags.append((m, m._require_backward_grad_sync))
                 m._require_backward_grad_sync = False
+                if self.optimize_backward_concat:
+                    # Set the flag on the wrapped FlattenParamsWrapper module as well,
+                    # so that FlattenParamsWrapper can accumulate grads at the corresponding
+                    # leaf nodes without triggering concat operations when gradient
+                    # synchronization is not needed.
+                    m._fsdp_wrapped_module._require_backward_grad_sync = False
         try:
             yield
         finally:
             for m, old_flag in old_flags:
                 assert m._require_backward_grad_sync is False
                 m._require_backward_grad_sync = old_flag
+                if self.optimize_backward_concat:
+                    m._fsdp_wrapped_module._require_backward_grad_sync = old_flag

     @contextlib.contextmanager
     def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Generator:
@@ -1744,10 +1764,17 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         self._use_fp32_param_shard([param])

         if self.fp32_reduce_scatter:
-            if getattr(param, "unsharded_main_grad", None) is None:
-                param.unsharded_main_grad = param.grad.to(torch.float32)
+            if self.optimize_backward_concat:
+                # Flatten and concat the accumulated fp32 grads
+                # and assign them to param.unsharded_main_grad
+                param.unsharded_main_grad = torch.cat([grad.flatten() for grad in self._fsdp_wrapped_module.fp32_grads])
+                # Clean up accumulated grads between data batches
+                self._fsdp_wrapped_module.fp32_grads = []
             else:
-                param.unsharded_main_grad.add_(param.grad.data)
+                if getattr(param, "unsharded_main_grad", None) is None:
+                    param.unsharded_main_grad = param.grad.to(torch.float32)
+                else:
+                    param.unsharded_main_grad.add_(param.grad.data)

             param.grad = None

@@ -1896,7 +1923,16 @@ def _wait_for_post_backward(self) -> None:
             # all the params, the post_backward hook will not fire and the
             # state will remain in `TrainingState.BACKWARD_PRE`.
             if any([p.requires_grad for p in self.params]):
-                self.assert_state(TrainingState.BACKWARD_POST)
+                if self.optimize_backward_concat:
+                    # If self.optimize_backward_concat == True, the FSDP backward pass should
+                    # only be triggered (which will invoke concat())
+                    # when self._fsdp_wrapped_module._require_backward_grad_sync is True
+                    if self._fsdp_wrapped_module._require_backward_grad_sync:
+                        self.assert_state(TrainingState.BACKWARD_POST)
+                    else:
+                        self.assert_state(TrainingState.BACKWARD_PRE)
+                else:
+                    self.assert_state(TrainingState.BACKWARD_POST)
             else:
                 self.assert_state(TrainingState.BACKWARD_PRE)

@@ -1981,7 +2017,16 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
                 # all the params, the post_backward hook will not fire and the
                 # state will remain in `TrainingState.BACKWARD_PRE`.
                 if any([p.requires_grad for p in m.params]):
-                    m.assert_state(TrainingState.BACKWARD_POST)
+                    if self.optimize_backward_concat:
+                        # If self.optimize_backward_concat == True, the FSDP backward pass should
+                        # only be triggered (which will invoke concat())
+                        # when self._fsdp_wrapped_module._require_backward_grad_sync is True
+                        if self._fsdp_wrapped_module._require_backward_grad_sync:
+                            m.assert_state(TrainingState.BACKWARD_POST)
+                        else:
+                            m.assert_state(TrainingState.BACKWARD_PRE)
+                    else:
+                        m.assert_state(TrainingState.BACKWARD_POST)
                 else:
                     m.assert_state(TrainingState.BACKWARD_PRE)
             else:
diff --git a/fairscale/nn/misc/flatten_params_wrapper.py b/fairscale/nn/misc/flatten_params_wrapper.py
index 38265dd2b..3b0b9f6e5 100644
--- a/fairscale/nn/misc/flatten_params_wrapper.py
+++ b/fairscale/nn/misc/flatten_params_wrapper.py
@@ -33,6 +33,7 @@
 from fairscale.experimental.nn.ssd_offload import SsdFlatParameter
 from fairscale.utils.state_dict import replace_by_prefix_

+import functools

 if TYPE_CHECKING:
     from collections import OrderedDict  # noqa: F401
@@ -148,6 +149,11 @@ class FlattenParamsWrapper(nn.Module):
         flat_param_names (Optional[List[str]]):
             originally, give each flat_param a unique name. Note a "flat_param_"
             prefix will be added to those names.
+        optimize_backward_concat (bool):
+            If True, only let the backward pass propagate to the corresponding FSDP.params, which will
+            invoke the FSDP._post_backward_hook() and concat() op, when _require_backward_grad_sync
+            is True (e.g. on the last microbatch).
+            NOTE: this likely will incur more GPU memory usage
     """

     def __init__(
@@ -157,10 +163,18 @@ def __init__(
         flat_param_names: Optional[List[str]] = None,
         ssd_offload: bool = False,
         ssd_directory: str = "",
+        optimize_backward_concat: bool = False,
     ):
         super().__init__()
         self._fpw_module = module
         self.is_flattened = False
+        self.optimize_backward_concat = optimize_backward_concat
+        # If optimize_backward_concat == True, used to propagate the
+        # corresponding FSDP module's _require_backward_grad_sync flag
+        self._require_backward_grad_sync = True
+        # If optimize_backward_concat == True, used to accumulate the
+        # fp32 gradients for the flattened parameters
+        self.fp32_grads = []

         # Handle param_list being None.
         if param_list is None:
@@ -364,18 +378,60 @@ def _unflatten_params(self, external_data: Optional[List[Optional[Tensor]]] = None) -> None:
                 delattr(self, n)
         self.flat_params = []

+    # The post backward hook used to accumulate fp32 gradients
+    def _grad_accumulation_hook(
+        self,
+        grad,
+        param_index,
+    ):
+        if self.fp32_grads[param_index] is None:
+            self.fp32_grads[param_index] = grad.to(torch.float32)
+        else:
+            self.fp32_grads[param_index].add_(grad)
+        return grad
+
     def _unflatten_params_as_views(self) -> None:
         """Unlike ``_unflatten_params``, this function unflatten into views and keep
         self.flat_param unchanged.
         """
         assert self.is_flattened
-        ps = self.get_param_views()
+        if self.optimize_backward_concat:
+            # If self._require_backward_grad_sync == True (e.g. on the last microbatch),
+            # we use the original flat_params as autograd leaf nodes and the backward
+            # pass should propagate all the way back to the FSDP module and thus invoke
+            # the FSDP post_backward() hook and concat() op.
+            # Otherwise we stop the backward propagation before the FSDP module to avoid
+            # invoking concat() and store the accumulated fp32 grads instead.
+            if self._require_backward_grad_sync:
+                ps = self.get_param_views()
+            else:
+                with torch.no_grad():
+                    ps = self.get_param_views()
+        else:
+            ps = self.get_param_views()
+
         param_views = []
         for (_, m, n), p in zip(self._param_infos, ps):
             setattr(p, '_fsdp_weight', True)
             setattr(m, n, p)  # This will set as plain attr
+            if self.optimize_backward_concat:
+                # The param_index of parameter p, used to accumulate the corresponding
+                # gradients in self.fp32_grads
+                param_index = len(param_views)
+                # Register a post backward hook to accumulate the gradients
+                # in self.fp32_grads
+                p.register_hook(
+                    functools.partial(
+                        self._grad_accumulation_hook,
+                        param_index=param_index
+                    )
+                )
             param_views.append(p)

+        if self.optimize_backward_concat and len(self.fp32_grads) == 0:
+            # Allocate self.fp32_grads at the beginning of each data batch's forward()
+            self.fp32_grads = [None] * len(param_views)
+
         # Save param views for easy access if anyone still wants to access
         # parameters of the module.
         setattr(self._fpw_module, "_unflattened_param_views", param_views)
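Usage sketch (assumption, not part of the patch): the new flag is intended to pair with gradient accumulation via `no_sync()`, so the per-parameter fp32 grads are only flattened and concatenated on the final microbatch. The model, optimizer, and microbatch data below are placeholders; apart from `optimize_backward_concat`, the FSDP constructor arguments shown are existing fairscale options, and a distributed process group plus a CUDA device are assumed to be available.

```python
# Hedged sketch: exercising optimize_backward_concat with gradient accumulation.
# Assumes torch.distributed is already initialized (e.g. via torchrun).
import torch
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

model = FSDP(
    nn.Linear(1024, 1024).cuda(),
    flatten_parameters=True,        # FlattenParamsWrapper is what accumulates the fp32 grads
    mixed_precision=True,
    fp32_reduce_scatter=True,       # required: the patch asserts this when the flag is on
    optimize_backward_concat=True,  # new flag added by this patch
)
optim = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Placeholder microbatches standing in for a real data loader.
microbatches = [torch.randn(8, 1024, device="cuda") for _ in range(4)]

optim.zero_grad()
# All but the last microbatch run under no_sync(): _require_backward_grad_sync is
# False, so backward stops at the unflattened views and only accumulates
# per-parameter fp32 grads via _grad_accumulation_hook (no flatten/concat).
for mb in microbatches[:-1]:
    with model.no_sync():
        model(mb).sum().backward()
# Last microbatch: grad sync is required again, so backward reaches FSDP's
# _post_backward_hook, which concats the accumulated fp32 grads into
# param.unsharded_main_grad and reduce-scatters them.
model(microbatches[-1]).sum().backward()
optim.step()
```

The memory note in the docstring follows from this flow: the per-parameter fp32 grads are held in FlattenParamsWrapper.fp32_grads until the final microbatch instead of being folded into a single unsharded_main_grad buffer after every backward.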