[fix]: let FSDP handle model with multiple forward pass and checkpoint #621
Merged
Commits (20)
b095405  [fix]: let FSDP handle model with multiple forward pass and checkpoint
16bf375  Merge remote-tracking branch 'origin/master' into min/bug_617
927aedc  try CI again
bdcf21d  Merge remote-tracking branch 'origin/master' into min/bug_617
bace9bd  save
f095358  Merge remote-tracking branch 'origin/master' into min/bug_617
e020daa  save
73e85d1  fixed case with bn
23fcdc5  Merge remote-tracking branch 'origin/master' into min/bug_617
69f476c  minor
7c5c96f  add the new file
bf6c7e4  minor
1246147  Merge remote-tracking branch 'origin/master' into min/bug_617
3cd8260  Merge remote-tracking branch 'origin/master' into min/bug_617
8683f0b  added test of a single case, runtime is about 50s
9fd8b69  enable all 8 test cases
5e65399  cleanup
c04a066  cleanup
e6246cc  skip flatten case with 1.6 and 1.7
4e0951d  minor
@@ -115,9 +115,15 @@ class FullyShardedDataParallel(nn.Module):
         an assert on the backward pass. The solution is to leave some parameters
         to the outer FSDP.
 
+    .. warning::
+
+        If activation checkpointing is used with FSDP, it is strongly encouraged
+        to use ``checkpoint_wrapper`` function from FairScale instead of the
+        ``checkpoint`` function from PyTorch.
+
     Args:
         module (nn.Module):
-            module to checkpoint
+            module to be wrapped with FullyShardedDataParallel.
         process_group (Optional):
            process group for sharding
         reshard_after_forward (bool, Optional):
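A minimal sketch of the pattern the new warning recommends: wrap the checkpointed block with FairScale's ``checkpoint_wrapper`` before handing it to FSDP. Layer sizes are illustrative, the import paths may vary between FairScale versions, and an initialized ``torch.distributed`` process group is assumed.

```python
import torch.nn as nn

from fairscale.nn import checkpoint_wrapper  # FairScale's wrapper, not torch.utils.checkpoint
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Illustrative block; assumes torch.distributed.init_process_group() has already run.
block = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 128))
model = FSDP(checkpoint_wrapper(block))  # activation-checkpointed block wrapped by FSDP
```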
@@ -207,7 +213,7 @@ def __init__(
         self.no_broadcast_optim_state = no_broadcast_optim_state
         self.state_dict_device = state_dict_device or self.compute_device
 
-        self.gradient_predivide_factor: int = self.get_gradient_predivide_factor(self.world_size)
+        self.gradient_predivide_factor: float = self._get_gradient_predivide_factor(self.world_size)
         self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
 
         self.numel_padded_per_param: List[int] = []
@@ -275,11 +281,31 @@ def __init__(
             f"FSDP.__init__(done): total_init_time: {(init_end - init_start): .4f} num_params: {(sum(p.numel() for p in self.params))}"
         )
 
-    def get_gradient_predivide_factor(self, world_size: int) -> int:
-        factor = 1
+        # Flag to guard multiple pre-forward hook being executed per iteration.
+        # This is reset at the end of the backward().
+        self._pre_backward_hook_has_run = False
+
+    def _get_gradient_predivide_factor(self, world_size: int) -> float:
+        factor: int = 1
         while world_size % factor == 0 and world_size / factor > factor:
-            factor = factor * 2
-        return factor
+            factor *= 2
+        return float(factor)
+
+    def set_gradient_divide_factors(self, pre: float, post: float, recursive: bool) -> None:
+        """Allowing user to override the pre and post divide factors.
+
+        Args:
+            pre (float): divide factor before the reduction.
+            post (float): divide factor after the reduction.
+            recursive (bool): recursively set it for all child FSDP instances or not.
+        """
+        self.assert_state(TrainingState.IDLE)
+        if recursive:
+            for module in self.modules():
+                if isinstance(module, FullyShardedDataParallel) and module != self:
+                    module.set_gradient_divide_factors(pre, post, False)
+        self.gradient_predivide_factor = pre
+        self.gradient_postdivide_factor = post
 
     @property
     def module(self) -> nn.Module:
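For reference, a standalone sketch of what the renamed helper computes and how the new override could be used; the world sizes and override values below are only examples.

```python
def get_gradient_predivide_factor(world_size: int) -> float:
    # Same arithmetic as FSDP._get_gradient_predivide_factor above: grow the
    # factor by powers of two while it still divides world_size and stays below
    # world_size / factor, roughly balancing the pre- and post-reduce division.
    factor = 1
    while world_size % factor == 0 and world_size / factor > factor:
        factor *= 2
    return float(factor)

assert get_gradient_predivide_factor(8) == 4.0    # pre-divide by 4, post-divide by 8 / 4 = 2
assert get_gradient_predivide_factor(64) == 8.0   # pre-divide by 8, post-divide by 64 / 8 = 8

# Overriding both factors on a wrapped model (example values only):
# model.set_gradient_divide_factors(1.0, 64.0, recursive=True)
```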
@@ -943,7 +969,13 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
             self._use_fp32_param_shard()
 
         # Register pre-backward hooks to all-gather the params for the backward
-        # pass (if needed).
+        # pass (if output's grad was needed). This won't register anything if
+        # we are in eval mode.
+        #
+        # Some model does forward pass multiple times, we need to register the
+        # pre-backward hook on every output since the last output's hook has to
+        # fire first to setup for backward. However, we use ``self._pre_backward_hook_has_run``
+        # to prevent repeated overhead from multiple hook callbacks.
         outputs = self._register_pre_backward_hooks(outputs)
 
         # Done with a forward pass.
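The comment above is about models like the toy example below, where a wrapped submodule runs forward more than once before a single backward, so a pre-backward hook gets attached to every output. Names and sizes are made up for the sketch.

```python
# The same submodule (imagine it wrapped in FSDP) runs forward twice per
# iteration, so hooks are registered on both outputs; the shared flag keeps
# the pre-backward setup from running twice.
import torch
import torch.nn as nn

class TwoPassModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.trunk = nn.Linear(16, 16)  # stand-in for an FSDP-wrapped submodule

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a = self.trunk(x)   # first forward pass through the submodule
        b = self.trunk(a)   # second forward pass before any backward
        return (a + b).sum()

loss = TwoPassModel()(torch.randn(4, 16))
loss.backward()  # a single backward covering both forward passes
```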
@@ -953,16 +985,18 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
 
     def _register_pre_backward_hooks(self, outputs: Any) -> Any:
         """Register pre-backward hook to run before the wrapped module's
-        backward. Hooks should be attached to all outputs from the forward."""
+        backward. Hooks should be attached to all outputs from the forward.
+
+        Returns:
+            outputs: new outputs with hooks registered if they requires gradient.
+        """
         if not torch.is_grad_enabled():
             return outputs  # don't register hooks if grad isn't enabled
 
-        pre_backward_hook_has_run = [False]
-
         def _pre_backward_hook(*unused: Any) -> None:
-            if pre_backward_hook_has_run[0]:
-                return  # only run once
-            pre_backward_hook_has_run[0] = True
+            if self._pre_backward_hook_has_run:
+                return  # only run once (from multiple outputs or multiple forward passes)
+            self._pre_backward_hook_has_run = True
 
             # Start of a backward pass.
             self.assert_state([TrainingState.IDLE, TrainingState.BACKWARD_PRE])
@@ -1062,13 +1096,27 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         the local optimizer only sees the relevant parameter shard.
         """
         # First hook callback will see PRE state. If we have multiple params,
-        # then subsequent hook callbacks will see POST state.
-        self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
+        # then subsequent hook callbacks will see POST state. When checkpoint
+        # fwd counter is used, IDLE is also possible since the pre-backward hook
+        # is not triggered (see ``auto_wrap_bn`` below, we have to use
+        # FSDP(checkpoint(conv, FSDP(bn), ...)), with reshard_after_forward=False).
+        if hasattr(self, "_checkpoint_fwd_counter"):
+            self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST, TrainingState.IDLE])
+        else:
+            self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
         self.training_state = TrainingState.BACKWARD_POST
         if param.grad is None:
             return
 
         if param.grad.requires_grad:
-            raise RuntimeError("FullyShardedDataParallel only works with gradients that don't require grad")
+            raise RuntimeError("FullyShardedDataParallel only works with gradients that don't require gradients")
+
+        # If this is a checkpointed module, we check if the following
+        # counter reaches 0. If not, it is not the final backward call
+        # for this module yet. Therefore, we early return in that case.
+        if hasattr(self._fsdp_wrapped_module, "_checkpoint_fwd_counter"):
+            if self._fsdp_wrapped_module._checkpoint_fwd_counter != 0:
+                return
 
         if self._require_backward_grad_sync or self.reshard_after_forward:
             # Free full params. As a special case, we don't free the full params

Review thread on the changed RuntimeError message:
Comment: Is this a thing? Which gradients need gradients!?
Reply: Great question! I think it came from @myleott originally. I would love to know more too. :-)
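A hedged sketch of the counter idea these checks rely on (this is not FairScale's actual ``checkpoint_wrapper`` code, just an illustration of the bookkeeping): each checkpointed forward increments a counter on the wrapped module, each backward recomputation decrements it, and the post-backward cleanup only proceeds once the counter is back to zero.

```python
# Illustrative only; the attribute name mirrors the one checked in the diff above.
import torch.nn as nn

def _bump_counter(module: nn.Module, delta: int) -> None:
    current = getattr(module, "_checkpoint_fwd_counter", 0)
    module._checkpoint_fwd_counter = current + delta

def on_checkpointed_forward(module: nn.Module) -> None:
    _bump_counter(module, +1)   # one more backward recomputation still expected

def on_backward_recompute(module: nn.Module) -> None:
    _bump_counter(module, -1)   # this forward's backward has arrived

def is_final_backward(module: nn.Module) -> bool:
    # Post-backward cleanup (freeing full params, reducing grads) would only
    # run when this returns True.
    return getattr(module, "_checkpoint_fwd_counter", 0) == 0
```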
@@ -1200,6 +1248,7 @@ def _remove_shard_bwd_hook(fsdp_module: FullyShardedDataParallel) -> None:
         for m in self.modules():  # includes self
             if isinstance(m, FullyShardedDataParallel):
                 _remove_shard_bwd_hook(m)
+                m._pre_backward_hook_has_run = False
                 if m._has_params:
                     if any(p.requires_grad for p in m.params):
                         m.assert_state(TrainingState.BACKWARD_POST)
@@ -1395,8 +1444,8 @@ def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None
             # In case we are failing in the context of autograd hook, asserting
             # may not generate useful msg. So, let's print it to be sure.
             if self.rank == 0:
-                print(self)
-                print(msg)
+                print(f"Asserting FSDP instance is: {self}")
+                print(f"ERROR: {msg}")
                 traceback.print_stack()
             raise ValueError(msg)
@@ -1543,7 +1592,7 @@ def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any])
                     v_shard = v[0] if self.rank >= len(v) else v[self.rank]
                     assert ou.is_singleton_tensor(v_shard)
                 else:
-                    v_shard = v  # dont shard entries that are not tensors
+                    v_shard = v  # don't shard entries that are not tensors
                 full_optim_state_dict["state"][id][k] = v_shard
 
         return full_optim_state_dict
@@ -1686,6 +1735,10 @@ def wrap_bn_only_policy(module: nn.Module, recurse: bool, unwrapped_params: int)
         "process_group": pg,
         "mixed_precision": False,  # Keep the weights in FP32.
         "flatten_parameters": False,  # Do not flatten.
+        # Reshard==False is good for performance. When FSDP(checkpoint(FSDP(bn))) is used, this
+        # **must** be False because BN's FSDP wrapper's pre-backward callback isn't called
+        # within the checkpoint's outer backward when multiple forward passes are used.
+        "reshard_after_forward": False,
     }
 
     with enable_wrap(wrap_bn_only_policy, **fsdp_config):
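One possible way the nesting described in the comment could be assembled, with ``auto_wrap_bn`` giving the BatchNorm layers their own (now non-resharding) FSDP wrappers inside an outer checkpointed FSDP block. Import paths, layer sizes, and the exact call signature are assumptions, and an initialized process group is expected.

```python
import torch.nn as nn

from fairscale.nn import checkpoint_wrapper
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import auto_wrap_bn  # assumed import path

# Illustrative conv + BN block; assumes torch.distributed is already initialized.
block = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
block = auto_wrap_bn(block)                  # BN layers get their own FSDP wrappers
model = FSDP(checkpoint_wrapper(block))      # outer FSDP around the checkpointed block
```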
@@ -1,4 +1,5 @@
-tests/nn/misc/test_flatten_params_wrapper.py
-tests/nn/data_parallel/test_fsdp.py
-tests/nn/data_parallel/test_fsdp_freezing_weights.py
+tests/nn/data_parallel/test_fsdp_multiple_forward_checkpoint.py
 tests/nn/data_parallel/test_fsdp_multiple_wrapping.py
+tests/nn/data_parallel/test_fsdp_freezing_weights.py
+tests/nn/data_parallel/test_fsdp.py
+tests/nn/misc/test_flatten_params_wrapper.py
Review comment (on moving the flag to the instance): Originally, the variable was local to _register_pre_backward_hooks and only prevented duplicate callbacks registered within that single call. When there are multiple calls to this function (multiple forward passes), the local flag won't prevent multiple callbacks. Moving this flag to self solves this problem.
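A toy illustration of the difference described here, using simplified stand-ins rather than the FSDP code itself: a flag local to the registering function only de-duplicates hooks created by that one call, while a flag stored on the instance de-duplicates across every forward pass in the iteration.

```python
from typing import Callable


class Wrapper:
    def __init__(self) -> None:
        self._pre_backward_hook_has_run = False  # in FSDP this is reset at the end of backward

    def register_pre_backward_hooks(self) -> Callable[[], None]:
        def _pre_backward_hook() -> None:
            if self._pre_backward_hook_has_run:
                return  # already ran this iteration, even if registered by another forward
            self._pre_backward_hook_has_run = True
            print("pre-backward setup runs exactly once per iteration")

        return _pre_backward_hook


w = Wrapper()
hook_a = w.register_pre_backward_hooks()  # hook from the first forward pass
hook_b = w.register_pre_backward_hooks()  # hook from the second forward pass
hook_a()
hook_b()  # setup still runs only once, because the flag lives on the instance
```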